Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import sys, asyncio
################
# Server class #
################
class Server:
    """Asyncio chat server relaying public and private messages between clients.

    Each chunk of at most 100 bytes read from a client is treated as one
    message:

    * ``@server set_my_id(NAME)`` registers the sender as NAME, or renames it
      if it is already registered.
    * ``@NAME text`` from a registered client is delivered only to the client
      registered as NAME, formatted as ``"#" + sender + ":" + text``.  If NAME
      is not a registered username the message is broadcast like any other.
    * Anything else from a registered client is broadcast to every other
      registered client as ``"sender: message"``.

    Replies addressed to the sender itself start with ``@client``.

    Attributes:
        address, port: where ``run`` listens.
        all_clients: writers of every connected client.
        registered_clients: writers of clients that registered a username.
        client_dict: username -> writer of the client that owns the name.
    """

    default_server_address = '127.0.0.1'
    default_server_port = 8888

    # NOTE: you can modify __init__
    def __init__(self, server_address=default_server_address, server_port=default_server_port):
        self.address = server_address
        self.port = server_port
        self.all_clients = set()
        self.registered_clients = set()
        self.client_dict = {}

    def _username_for(self, writer):
        """Return the username currently owned by *writer*, or None if unregistered."""
        for name, owner in self.client_dict.items():
            if owner is writer:
                return name
        return None

    # NOTE: the following method must be implemented for some of our grading tests to work. If you don't implement this method correctly, you will lose some marks!
    # method for registering usernames
    def set_username(self, new_username, writer, old_username=None):
        """Register *new_username* for the client behind *writer*.

        If the client already owns a username, that name is released, so each
        client owns exactly one name.  A confirmation
        (``@client username set to NAME``) or an error
        (``@client ERROR User already registered``) is written to *writer*.

        Args:
            new_username: the requested username.
            writer: stream writer of the requesting client.
            old_username: optional hint for the client's current name.  It is
                used only if it really belongs to *writer*; otherwise the
                current name is looked up in ``client_dict``.
        """
        if new_username in self.client_dict:
            print("@client ERROR User already registered")
            writer.write("@client ERROR User already registered".encode())
            return
        if old_username is None or self.client_dict.get(old_username) is not writer:
            # Search every entry, not just the first one.
            old_username = self._username_for(writer)
        if old_username is not None:
            del self.client_dict[old_username]
        self.client_dict[new_username] = writer
        output = "@client username set to " + new_username
        print(output)
        writer.write(output.encode())
        self.registered_clients.add(writer)

    # NOTE: this method must be implemented for some of our grading tests to work. If you don't implement this method correctly, you will lose some marks!
    # method that returns all the registered usernames as a list
    def get_registered_usernames_list(self):
        """Return the currently registered usernames as a list."""
        return list(self.client_dict.keys())

    # NOTE: you can modify the implementation of handle_connection (but not its signature)
    async def handle_connection(self, reader, writer):
        """Serve one client until it disconnects, then release everything it held.

        The cleanup runs in ``finally``, so the client's writer and username
        are removed even if handling a message raises.  Without this, the
        username would stay taken and keep pointing at a dead connection.
        """
        self.all_clients.add(writer)
        client_addr = writer.get_extra_info('peername')
        print('New client {}'.format(client_addr))
        try:
            while True:
                data = await reader.read(100)
                if not data:
                    break
                # A multi-byte character can be split across two reads; don't
                # let that crash the handler.
                message = data.decode(errors='replace')
                await self._dispatch(message, writer, client_addr)
        finally:
            print("Closing connection with client {}".format(client_addr))
            writer.close()
            self.all_clients.discard(writer)
            self.registered_clients.discard(writer)
            name = self._username_for(writer)
            if name is not None:
                # Free the username so that it can be claimed again.
                del self.client_dict[name]

    async def _dispatch(self, message, writer, client_addr):
        """Route one message received from *writer* (registration, private or public)."""
        if message.startswith("@server "):
            self._register(message, writer)
            await self._flush(writer)
            return
        print("Received {} from {}".format(message, client_addr))
        sender = self._username_for(writer)
        if sender is None:
            print("@client ERROR Cannot send messages as an unregistered user")
            writer.write("@client ERROR Cannot send messages as an unregistered user".encode())
            await self._flush(writer)
            return
        if message.startswith("@") and await self._send_private(sender, message):
            return
        await self._broadcast(sender, message, writer)

    def _register(self, message, writer):
        """Handle an ``@server`` command; only ``set_my_id(NAME)`` is accepted."""
        prefix = "@server set_my_id("
        name = message[len(prefix):-1]
        # Whitespace in a name is rejected because private messages are routed
        # by the first whitespace-delimited word; such a name (or an empty
        # one) could never be addressed.
        if message.startswith(prefix) and message.endswith(")") and name.split() == [name]:
            self.set_username(name, writer)
        else:
            writer.write("@client ERROR Message does not match the username registration format".encode())

    async def _send_private(self, sender, message):
        """Deliver ``@NAME text`` to NAME only; return False if NAME is not registered.

        The target is the whole first word, so ``@bobby hi`` reaches ``bobby``
        only and not also ``bob``.
        """
        rest = message[1:]
        target = rest.split()[0] if rest[:1].strip() else ""
        target_writer = self.client_dict.get(target) if target else None
        if target_writer is None:
            return False
        # Keep the text exactly as typed after the name, including its leading space.
        text = message[1 + len(target):]
        target_writer.write(("#" + sender + ':' + text).encode())
        await self._flush(target_writer)
        return True

    async def _broadcast(self, sender, message, writer):
        """Send ``"sender: message"`` to every registered client except *writer*."""
        new_message = '{}: {}'.format(sender, message)
        # Iterate over a snapshot: other handlers may add or remove clients
        # while this one is suspended in drain().
        for other_writer in list(self.registered_clients):
            if other_writer is not writer:
                other_writer.write(new_message.encode())
                await self._flush(other_writer)

    @staticmethod
    async def _flush(target_writer):
        """Wait for *target_writer*'s buffer to drain, tolerating peers that hung up."""
        try:
            await target_writer.drain()
        except ConnectionError:
            # That peer's own handler will see EOF and clean it up; a dead
            # recipient must not kill the sender's connection.
            pass

    # NOTE: do not modify run
    # NOTE(review): asyncio.Task.all_tasks() was removed in Python 3.9 and the
    # loop= argument of start_server in 3.10, so run() as written needs <= 3.8.
    def run(self):
        loop = asyncio.get_event_loop()
        coro = asyncio.start_server(self.handle_connection,self.address,
                                    self.port,loop=loop)
        server = loop.run_until_complete(coro)
        print('Serving on {}'.format(server.sockets[0].getsockname()))
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            print('\nGot keyboard interrupt, shutting down',file=sys.stderr)
        for task in asyncio.Task.all_tasks():
            task.cancel()
        server.close()
        loop.run_until_complete(server.wait_closed())
        loop.close()
# NOTE: do not modify the following two lines
# Start the chat server when this module is executed as a script.
# NOTE(review): this paste appears to concatenate the server and client
# modules; as a single file, this guard would run the server first and only
# reach the client definitions below after it shuts down. Keep them as
# separate files.
if __name__ == '__main__':
    Server().run()
- import sys, asyncio
- import aioconsole
#####################
# Custom exceptions #
#####################
class NoneException(Exception):
    """Custom exception type; nothing in the client code visible here raises it."""


class ClosingException(Exception):
    """Signals that the user typed ``close()`` and the client should disconnect."""
################
# Client class #
################
class Client:
    """Console chat client: sends typed lines to the server and prints its messages."""

    default_server_address = '127.0.0.1'
    default_server_port = 8888

    # NOTE: you can modify __init__
    def __init__(self, server_address=default_server_address, server_port=default_server_port):
        self.server_address = server_address
        self.server_port = server_port
        self.name = None # NOTE: do not remove the attribute from your client implementation, but use it to store the registered username associated to this client. This is a pre-condition for some of our grading tests to work correctly.

    # NOTE: do not modify open_connection
    # NOTE(review): the generator-based @asyncio.coroutine decorator was
    # removed in Python 3.11 and the loop= argument in 3.10.
    @asyncio.coroutine
    def open_connection(self,loop):
        reader, writer = yield from asyncio.open_connection(
            self.server_address, self.server_port, loop=loop)
        return reader, writer

    # NOTE: do not modify use_connection
    @asyncio.coroutine
    def use_connection(self,reader, writer):
        yield from asyncio.gather(self.read_from_network(reader,writer),
                                  self.send_to_server(writer))

    # NOTE: you can modify the implementation of read_from_network (but not its signature)
    async def read_from_network(self, reader, writer):
        """Print messages from the server until either side closes the connection.

        Messages are classified by their *prefix*:

        * ``@client username set to NAME`` stores NAME in ``self.name``;
        * ``@client ERROR ...`` is shown as a ``[server]`` message;
        * ``#...`` is a private message;
        * anything else is a public broadcast.

        ``startswith`` is used instead of ``in`` so that a public message that
        merely contains one of these markers cannot be mistaken for a server
        reply. For example, another user typing "@client username set to x"
        must not overwrite ``self.name``.
        """
        username_prefix = "@client username set to "
        while True:
            net_message = await reader.read(100)
            if writer.transport.is_closing():
                # Our side is closing (e.g. send_to_server closed the writer).
                print('Terminating read from network.')
                break
            if net_message is None:
                continue
            if len(net_message) == 0:
                print('The server closed the connection.')
                writer.close()
                break
            if isinstance(net_message, str):
                message = net_message
            else:
                # A multi-byte character may be split across two reads.
                message = net_message.decode(errors='replace')
            if message.startswith(username_prefix):
                self.name = message[len(username_prefix):].strip()
                print('\n%s' % message)
            elif message.startswith("@client ERROR"):
                # Drop the leading "@client " (8 characters).
                print("\n[server] " + message[8:])
            elif message.startswith("#"):
                print('\n[private] %s' % message[1:])
            else:
                print('\n[public] %s' % message)
            print('>> ', end='', flush=True)

    # NOTE: you can modify the implementation of send_to_server (but not its signature)
    async def send_to_server(self, writer):
        """Read console lines and forward them to the server until ``close()`` or EOF.

        Before sending, a ``set_my_id`` username is checked to be neither
        ``client``/``server`` nor contain spaces. Private messages addressed
        to our own registered name are refused locally. The writer is always
        closed on exit.
        """
        try:
            while True:
                original_message = await aioconsole.ainput('>> ')
                if original_message is None:
                    continue
                console_message = original_message.strip()
                if console_message == '':
                    continue
                if console_message == 'close()':
                    raise ClosingException()
                if console_message.startswith("@server"):
                    if '(' in console_message and ')' in console_message:
                        # Use the last ')' so the name matches what the server registers.
                        client_name = console_message[console_message.find("(") + 1:console_message.rfind(")")]
                        if client_name.lower() in ("client", "server"):
                            print("[error]: username cannot be client or server")
                            continue
                        if ' ' in client_name:
                            print("[error]: username cannot contain spaces")
                            continue
                elif (console_message.startswith('@') and self.name is not None
                        and console_message.split()[0] == "@" + self.name):
                    # Compare the whole first word so "@bobby" is not treated as "@bob".
                    print("[error] Cannot private message yourself")
                    continue
                writer.write(console_message.encode())
        except ClosingException:
            print('Got close() from user.')
        except EOFError:
            # End of console input (e.g. Ctrl-D): shut down cleanly instead of
            # surfacing the exception through use_connection.
            print('Got EOF from user.')
        finally:
            if not writer.transport.is_closing():
                writer.close()

    # NOTE: do not modify run
    def run(self):
        try:
            loop = asyncio.get_event_loop()
            reader,writer=loop.run_until_complete(self.open_connection(loop))
            loop.run_until_complete(self.use_connection(reader,writer))
        except KeyboardInterrupt:
            print('Got Ctrl-C from user.')
        except Exception as e:
            print(e,file=sys.stderr)
        finally:
            loop.close()
# NOTE: do not modify the following two lines
# Start the console chat client when this module is executed as a script.
if __name__ == '__main__':
    Client().run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement