Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import json
- import socket
- import socketserver
- import threading
- import time
- from queue import Queue, Empty
- import signal
class ThreadingTcpServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that dispatches each accepted connection to its own thread.

    All behaviour is inherited from ``socketserver.ThreadingMixIn`` and
    ``socketserver.TCPServer``; nothing is overridden here.
    """
class FooServerConnection(object):
    """A TCP (IPv4) connection to a foo server.

    Offers the small socket-like surface (``send``/``recv``) that the proxy's
    request handler and ``read_message`` use, plus ``close`` and
    context-manager support so the underlying socket can be released.
    """

    def __init__(self, ip_address, port):
        """Open a connection to ``(ip_address, port)``.

        Raises:
            OSError: if the connection cannot be established. The socket
                is closed before the error propagates.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((ip_address, port))
        except BaseException:
            # Don't leak the descriptor when the server is unreachable.
            sock.close()
            raise
        self._sock = sock

    def send(self, data):
        """Send all of *data* (bytes) to the server.

        ``sendall`` is used because ``socket.send`` may transmit only part
        of the buffer and report the count, which the caller never checked.
        """
        self._sock.sendall(data)

    def recv(self, buffer_size=1024):
        """Return up to *buffer_size* bytes from the server (b"" on EOF)."""
        return self._sock.recv(buffer_size)

    def close(self):
        """Close the underlying socket."""
        self._sock.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
def read_message(sock, term_char='\n', buffer_size=1024, timeout=1):
    """Read a message from *sock* that ends with *term_char*.

    Args:
        sock: any object with a ``recv(buffer_size)`` method returning bytes
            (here a client socket or a ``FooServerConnection``).
        term_char: terminator as ``str``; it is encoded before comparison.
        buffer_size: maximum number of bytes requested per ``recv`` call.
        timeout: seconds after which no further ``recv`` call is started.
            This guards against a client that sends a malformed message or
            stops part-way through sending.

    Returns:
        The bytes received. The result is empty if the peer closed the
        connection before sending anything. It may lack the terminator if
        the peer closed or the deadline passed mid-message. It may contain
        more than one message if the peer sent several back-to-back, because
        whole ``recv`` chunks are appended.

    Note:
        The deadline is only checked between ``recv`` calls; it does not
        interrupt a ``recv`` that is blocking.
    """
    deadline = time.time() + timeout
    terminator = term_char.encode()
    received = bytearray()
    while not received.endswith(terminator) and time.time() < deadline:
        chunk = sock.recv(buffer_size)
        if not chunk:
            # b"" means the peer closed the connection. Calling recv again
            # would just spin on empty reads until the deadline.
            break
        received += chunk
    return bytes(received)
class FooClientRequestHandler(socketserver.BaseRequestHandler):
    """Proxy handler that forwards messages between a foo client and a foo server.

    For every accepted client connection a new upstream connection to the
    foo server is opened. Each client message is recorded, forwarded
    upstream, and the upstream reply is recorded and relayed back.
    """

    # Must be set to a Proxy object before the class can be used
    # (Proxy.__init__ assigns it).
    proxy = None
    # TODO find a neater way to have the request handler send message to the proxy for storing stats

    def handle(self):
        """Relay request/response pairs until a read returns no data."""
        upstream = FooServerConnection(self.proxy.foo_server_ip_address, self.proxy.foo_server_port)
        try:
            while True:
                message = read_message(self.request)
                if not message:
                    break
                self.proxy.store_message(message)
                upstream.send(message)

                response = read_message(upstream)
                if not response:
                    break
                self.proxy.store_message(response)
                self.request.sendall(response)
        finally:
            # Release the upstream socket, otherwise every client connection
            # leaks a descriptor.
            upstream._sock.close()
def get_from_queue(q):
    """Drain every item currently available in *q* without blocking.

    Args:
        q: a ``queue.Queue`` or any subclass of it (e.g. ``LifoQueue``).

    Returns:
        A list of the items in the order ``q.get`` yields them; empty if
        the queue had nothing to give.

    Raises:
        TypeError: if *q* is not a ``queue.Queue``.
    """
    if not isinstance(q, Queue):
        raise TypeError("expected a queue.Queue, got {}".format(type(q).__name__))
    items = []
    # Stop only when get() raises Empty. Testing emptiness first would race
    # with producer threads.
    while True:
        try:
            items.append(q.get(block=False))
        except Empty:
            return items
class Proxy(object):
    """Foo proxy: relays foo clients to a foo server and keeps message stats.

    Client connections are served by ``FooClientRequestHandler``, which
    reports every relayed message through ``store_message``. A background
    thread maintains 1 s / 10 s sliding windows, and ``print_stats`` (installed
    as the SIGUSR2 handler by ``start``) prints the statistics as JSON.
    """

    # Seconds the stats thread waits between passes; without a pause the
    # loop spins at full CPU.
    _STATS_INTERVAL = 0.1

    def __init__(self, ip_address='127.0.0.1', client_port=8002, server_port=8001):
        """
        Args:
            ip_address: address the proxy listens on, and also the address
                of the foo server it forwards to.
            client_port: port the proxy listens on for clients.
            server_port: port of the foo server.
        """
        self.foo_server_ip_address = ip_address
        self.foo_server_port = server_port
        # The handler class reaches this proxy through a class attribute.
        FooClientRequestHandler.proxy = self
        self._sockserver = ThreadingTcpServer((ip_address, client_port), FooClientRequestHandler)
        self._serverthread = None  # not used within this class
        self._statsthread = None
        self._shutdown = threading.Event()
        # Queues carrying message timestamps from the connection threads to
        # the stats thread.
        self._reqs = Queue()
        self._acks = Queue()
        self._naks = Queue()
        # Reentrant: print_stats runs as a signal handler on the main thread,
        # so a second SIGUSR2 arriving while it holds the lock must not deadlock.
        self._stats_lock = threading.RLock()
        # Lifetime totals, updated by store_message under _stats_lock.
        self._total_reqs = 0
        self._total_acks = 0
        self._total_naks = 0
        # Timestamps inside the sliding windows, rebuilt by process_stats.
        self._reqs_10s = []
        self._acks_10s = []
        self._naks_10s = []
        self._reqs_1s = []
        self._acks_naks_1s = []

    def store_message(self, message):
        """Record one relayed message according to its prefix.

        Args:
            message: raw bytes as produced by ``read_message``.

        Messages starting with REQ, ACK or NAK have their arrival time queued
        for the stats thread and the matching total incremented. Anything
        else is ignored.
        """
        t = time.time()
        # Match on bytes. Decoding first would raise on non-UTF-8 client
        # input and kill the calling handler thread.
        if message.startswith(b"REQ"):
            self._reqs.put(t)
            with self._stats_lock:
                self._total_reqs += 1
        elif message.startswith(b"ACK"):
            self._acks.put(t)
            with self._stats_lock:
                self._total_acks += 1
        elif message.startswith(b"NAK"):
            self._naks.put(t)
            with self._stats_lock:
                self._total_naks += 1

    def process_stats(self):
        """Maintain the 10 s and 1 s timestamp windows until shutdown is requested."""
        while not self._shutdown.is_set():
            with self._stats_lock:
                self._reqs_10s += get_from_queue(self._reqs)
                self._acks_10s += get_from_queue(self._acks)
                self._naks_10s += get_from_queue(self._naks)
                t = time.time()
                self._reqs_10s = [r for r in self._reqs_10s if r > t - 10]
                self._acks_10s = [r for r in self._acks_10s if r > t - 10]
                self._naks_10s = [r for r in self._naks_10s if r > t - 10]
                self._reqs_1s = [r for r in self._reqs_10s if r > t - 1]
                self._acks_naks_1s = [a for a in self._acks_10s if a > t - 1] + [n for n in self._naks_10s if n > t - 1]
            # Pause between passes; returns early once stop() sets the event.
            self._shutdown.wait(self._STATS_INTERVAL)

    def print_stats(self, *args):
        """Print the stats to STDOUT as indented, key-sorted JSON.

        ``*args`` absorbs the (signum, frame) pair passed to signal handlers.
        """
        with self._stats_lock:
            print(json.dumps(
                {
                    "msg_tot": self._total_reqs + self._total_acks + self._total_naks,
                    "msg_reqs": self._total_reqs,
                    "msg_ack": self._total_acks,
                    "msg_nak": self._total_naks,
                    "request_rate_1s": len(self._reqs_1s),
                    "request_rate_10s": len(self._reqs_10s)/10,
                    "response_rate_10s": (len(self._acks_10s) + len(self._naks_10s))/10,
                    "response_rate_1s": len(self._acks_naks_1s)
                },
                indent=4,
                sort_keys=True
            ))

    def start(self):
        """Start the stats thread, install the SIGUSR2 handler, and serve clients.

        Blocks in ``serve_forever()``. ``signal.signal`` may only be called
        from the main thread, so ``start`` must be called from there too.
        """
        print("Starting Foo Proxy")
        # Daemon so the stats thread can never keep the process alive on its own.
        self._statsthread = threading.Thread(target=self.process_stats, daemon=True)
        self._statsthread.start()
        signal.signal(signal.SIGUSR2, self.print_stats)
        self._sockserver.serve_forever()

    def stop(self):
        """Stop the stats thread and close the listening socket.

        ``BaseServer.shutdown()`` is deliberately not called. It must run in a
        thread other than the one executing ``serve_forever()``, and it waits
        for a running ``serve_forever()`` loop to finish.
        """
        print("Stopping Foo Proxy")
        self._shutdown.set()
        self._sockserver.server_close()

    def __del__(self):
        # Best-effort cleanup. __init__ may have failed part-way and left
        # attributes unset, so errors are deliberately ignored here.
        try:
            self.stop()
        except Exception:
            pass
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement