Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
from __future__ import annotations

from dataclasses import dataclass
from functools import reduce
from hashlib import sha256
from queue import SimpleQueue
from random import randint, sample
from secrets import token_bytes
from time import time
from typing import Callable, Protocol
import readline
import struct

from nacl.public import Box, SealedBox
from nacl.signing import SigningKey, VerifyKey, SignedMessage
- """Proof-of-concept gossip protocol implementation. To use in a real
- application, import the module; create new handlers as necessary for
- the duck typed protocols SupportsSendAndDeliverMessage,
- SupportsHandleMessage, and SupportsHandleAction; and register the
- handlers as shown in main. Note that Node.from_seed is used only for
- the active node(s) while Node.__init__ can be used for neighbors
- for which we know only the public key/address. Most models include
- a data or metadata property for extensibility. The run_tick,
- format_address, and action_count functions can be used without
- modification. Debug message handling can be swapped out from print
- to a custom function using unregister_debug_handler(print) and
- register_debug_handler(custom_func), e.g. to write to a log file.
- """
def license() -> str:
    """Copyleft (c) 2022 k98kurz
    Permission to use, copy, modify, and/or distribute this software
    for any purpose with or without fee is hereby granted, provided
    that the above copyleft notice and this permission notice appear in
    all copies.
    THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
    WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
    AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR
    CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
    OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
    NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
    CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
    """
    # The docstring above is the license text itself: it is what this
    # function returns, so edit it only when the license changes.
    return license.__doc__
# global config and toggle/utility functions
# When True, debug() forwards each message to every entry in DEBUG_HANDLERS.
ENABLE_DEBUG = True
# When True, format_address() shows only the first 8 hex characters.
DISPLAY_SHORT_ADDRESSES = True
# When True, Node.receive_message rejects unsigned messages and
# Message.unpack expects the signed wire layout.
SIGN_MESSAGES = True
# When True, Node seals outgoing messages and unseals incoming ones.
ENCRYPT_MESSAGES = True
# Node.receive_message discards messages whose ts is more than this many
# seconds older than the current time().
MESSAGE_TTL = 300
# Callables invoked with each debug message while ENABLE_DEBUG is True.
DEBUG_HANDLERS = [print]
def format_address(address: bytes) -> str:
    """Render an address as hex for display.

    Returns the first 8 hex characters while DISPLAY_SHORT_ADDRESSES is
    set, otherwise the full hex string.
    """
    hex_form = address.hex()
    if DISPLAY_SHORT_ADDRESSES:
        return hex_form[:8]
    return hex_form
def toggle_short_address() -> bool:
    """Flip DISPLAY_SHORT_ADDRESSES and return its new value."""
    global DISPLAY_SHORT_ADDRESSES
    flipped = not DISPLAY_SHORT_ADDRESSES
    DISPLAY_SHORT_ADDRESSES = flipped
    return flipped
def debug(msg: str):
    """Forward msg to every registered debug handler.

    Does nothing while ENABLE_DEBUG is False.
    """
    if not ENABLE_DEBUG:
        return
    for handler in DEBUG_HANDLERS:
        handler(msg)
def register_debug_handler(c: Callable[[str], object]) -> None:
    """Register a new function for handling debug messages.

    Args:
        c: callable that debug() will invoke with each message string.
            Registering a handler that is already present is a no-op.

    Raises:
        TypeError: if c is not callable.
    """
    if not callable(c):
        raise TypeError('Can only register callables as debug handlers.')
    # The list is mutated in place, so no `global` declaration is needed.
    if c not in DEBUG_HANDLERS:
        DEBUG_HANDLERS.append(c)
def unregister_debug_handler(c: Callable[[str], object]) -> None:
    """Unregister a function from handling debug messages.

    Args:
        c: a previously registered handler. Removing a handler that is
            not registered is a no-op.

    Raises:
        TypeError: if c is not callable.
    """
    if not callable(c):
        raise TypeError('Can only deregister callables as debug handlers.')
    # The list is mutated in place, so no `global` declaration is needed.
    if c in DEBUG_HANDLERS:
        DEBUG_HANDLERS.remove(c)
def toggle_debug() -> bool:
    """Flip ENABLE_DEBUG and return its new value."""
    global ENABLE_DEBUG
    flipped = not ENABLE_DEBUG
    ENABLE_DEBUG = flipped
    return flipped
def toggle_sign_all_messages() -> bool:
    """Flip SIGN_MESSAGES and return its new value."""
    global SIGN_MESSAGES
    flipped = not SIGN_MESSAGES
    SIGN_MESSAGES = flipped
    return flipped
def toggle_encrypt_messages() -> bool:
    """Flip ENCRYPT_MESSAGES and return its new value."""
    global ENCRYPT_MESSAGES
    flipped = not ENCRYPT_MESSAGES
    ENCRYPT_MESSAGES = flipped
    return flipped
@dataclass
class Connection:
    """Connection model representing an edge connecting two Nodes together.

    Attributes:
        nodes: the two endpoint nodes, stored as a set (order-free).
        data: free-form dict for extensibility; starts empty.
    """
    nodes: set[Node]
    data: dict

    def __init__(self, nodes: list[Node]) -> None:
        """Create a connection between the two given nodes.

        Args:
            nodes: list of exactly two distinct nodes.

        Raises:
            Exception: if nodes is not a list of exactly 2 distinct nodes.
        """
        if not isinstance(nodes, list) or len(nodes) != 2:
            raise Exception('a Connection must connect exactly 2 nodes')
        endpoints = set(nodes)
        # The same node passed twice collapses to a single set element,
        # which would make __hash__ fail later; reject it here instead.
        if len(endpoints) != 2:
            raise Exception('a Connection must connect exactly 2 nodes')
        self.nodes = endpoints
        self.data = {}

    def __hash__(self) -> int:
        """Enable inclusion in sets.

        The endpoints are sorted first so the hash does not depend on
        the order in which the nodes were supplied.
        """
        first, second = sorted(self.nodes)
        return hash(first.address + second.address)
@dataclass
class Message:
    """Message model contains the source, destination, content, and
    optional signature and metadata.

    Attributes:
        src: sender address (public key bytes).
        dst: recipient address (public key bytes).
        ts: integer timestamp; defaults to int(time()) at creation.
        msg: payload bytes (ciphertext after encrypt()/seal()).
        sig: 64-byte signature, or None while unsigned.
        metadata: free-form dict for extensibility; starts empty.
    """
    src: bytes
    dst: bytes
    ts: int
    msg: bytes
    sig: bytes
    metadata: dict

    # Wire layouts (network byte order, no padding): dst, src, ts and,
    # for signed messages, the signature. The payload length is appended
    # per message. NOTE: 'i' is a signed 32-bit int, so ts must fit in it.
    _UNSIGNED_HEADER = '!32s32si'
    _SIGNED_HEADER = '!32s32si64s'

    def __init__(self, src: bytes, dst: bytes, msg: bytes, ts: int | None = None, sig: bytes | None = None) -> None:
        """Create a message; ts defaults to the current time."""
        self.src = src
        self.dst = dst
        self.ts = int(time()) if ts is None else ts
        self.msg = msg
        self.sig = sig
        self.metadata = {}

    def __repr__(self) -> str:
        return f"{format_address(self.src)}->{format_address(self.dst)}: {format_address(sha256(self.msg).digest())}"

    def __bytes__(self) -> bytes:
        """Return the bytes that sign() signs and verify() checks.

        NOTE(review): ts is not part of these bytes, so the signature
        does not cover the timestamp.
        """
        return self.src + self.dst + self.msg

    def __hash__(self) -> int:
        """Enable inclusion in sets."""
        return hash(bytes(self))

    def pack(self) -> bytes:
        """Pack the data with struct.

        The signed layout is used when sig is set, otherwise the
        unsigned layout.
        """
        if self.sig is not None:
            fstr = f'{self._SIGNED_HEADER}{len(self.msg)}s'
            return struct.pack(fstr, self.dst, self.src, self.ts, self.sig, self.msg)
        fstr = f'{self._UNSIGNED_HEADER}{len(self.msg)}s'
        return struct.pack(fstr, self.dst, self.src, self.ts, self.msg)

    @classmethod
    def unpack(cls, packed: bytes, signed: bool | None = None) -> Message:
        """Unpack the data with struct.

        Args:
            packed: bytes produced by pack().
            signed: whether packed uses the signed layout. When None
                (the default), the global SIGN_MESSAGES decides.

        Raises:
            ValueError: if packed is shorter than the selected header.
        """
        if signed is None:
            signed = SIGN_MESSAGES
        header = cls._SIGNED_HEADER if signed else cls._UNSIGNED_HEADER
        # Derive the payload length from the real header size (132 bytes
        # signed, 68 unsigned) rather than a hard-coded constant.
        payload_len = len(packed) - struct.calcsize(header)
        if payload_len < 0:
            raise ValueError('packed message is shorter than its header')
        fields = struct.unpack(f'{header}{payload_len}s', packed)
        if signed:
            dst, src, ts, sig, msg = fields
            return cls(src, dst, msg, ts, sig)
        dst, src, ts, msg = fields
        return cls(src, dst, msg, ts)

    def sign(self, skey: SigningKey) -> SignedMessage:
        """Generate a signature for the message.

        Stores the first 64 bytes of the signed message as sig and
        returns the full signed message.
        """
        sig = skey.sign(bytes(self))
        self.sig = sig[:64]
        return sig

    def verify(self) -> bool:
        """Verify the message signature.

        Returns False on any failure, including a missing signature or
        a src that is not a valid verify key.
        """
        try:
            vkey = VerifyKey(self.src)
            sig = SignedMessage(self.sig + bytes(self))
            vkey.verify(sig)
            return True
        except Exception:
            # Deliberately broad: every failure means "not verified". Unlike
            # a bare except, this lets KeyboardInterrupt/SystemExit through.
            return False

    def encrypt(self, skey: SigningKey) -> None:
        """Encrypt the message by the sender.

        Raises:
            ValueError: if skey does not belong to src.
        """
        if bytes(skey.verify_key) != self.src:
            raise ValueError('Must use the skey of the sender to encrypt.')
        privk = skey.to_curve25519_private_key()
        pubk = VerifyKey(self.dst).to_curve25519_public_key()
        box = Box(privk, pubk)
        self.msg = bytes(box.encrypt(self.msg))

    def decrypt(self, skey: SigningKey) -> None:
        """Decrypt the message by the receiver.

        Raises:
            ValueError: if skey does not belong to dst.
        """
        if bytes(skey.verify_key) != self.dst:
            raise ValueError('Must use the skey of the receiver to decrypt.')
        privk = skey.to_curve25519_private_key()
        pubk = VerifyKey(self.src).to_curve25519_public_key()
        box = Box(privk, pubk)
        self.msg = box.decrypt(self.msg)

    def seal(self) -> None:
        """Encrypt using ephemeral ECDHE."""
        sealed_box = SealedBox(VerifyKey(self.dst).to_curve25519_public_key())
        self.msg = sealed_box.encrypt(self.msg)

    def unseal(self, skey: SigningKey) -> None:
        """Decrypt using ephemeral ECDHE.

        Raises:
            ValueError: if skey does not belong to dst.
        """
        if bytes(skey.verify_key) != self.dst:
            raise ValueError('Must use the skey of the receiver to decrypt.')
        privk = skey.to_curve25519_private_key()
        sealed_box = SealedBox(privk)
        self.msg = sealed_box.decrypt(self.msg)
@dataclass
class Action:
    """A named unit of work for a Node.

    A Node queues Actions and hands each one to its registered action
    handler. The constructor is the one generated by @dataclass, taking
    (name, data) in that order.
    """
    name: str  # identifies which action the handler should perform
    data: dict  # payload for the action
class SupportsSendAndDeliverMessage(Protocol):
    """Duck type protocol for message sender.

    Node.process hands each queued outgoing Message to send().
    """
    def send(self, msg: Message) -> None:
        """Accept one outgoing message."""
        ...

    def deliver(self) -> None:
        """Deliver pending messages."""
        ...
class SupportsHandleMessage(Protocol):
    """Duck type protocol for incoming message handler.

    Node.process hands each queued incoming Message to handle().
    """
    def handle(self, msg: Message) -> None:
        """Handle one incoming message."""
        ...
class SupportsHandleAction(Protocol):
    """Duck type protocol for action handler.

    Node.process hands each queued Action to handle(); Node.queue_action
    only accepts Action instances, hence the parameter type.
    """
    def handle(self, action: Action) -> None:
        """Handle one queued action."""
        ...
@dataclass
class Node:
    """The core model representing a Node and handling its Connections,
    Actions, and Messages. Invoke with Node(address) for neighbors
    and Node.from_seed(seed) for an active node. Optional data
    property for extensibility. The address is the public key bytes
    of the node for when SIGN_MESSAGES is set.

    Attributes:
        address: public key bytes identifying the node.
        msgs_seen: set of byte values recorded for messages already
            handled (filled by the action handler, not by Node itself).
        connections: Connections this node takes part in.
        data: free-form dict for extensibility; starts empty.
    """
    address: bytes
    msgs_seen: set[bytes]
    connections: set[Connection]
    data: dict
    _seed: bytes
    _skey: SigningKey
    _vkey: VerifyKey
    _inbound: SimpleQueue
    _outbound: SimpleQueue
    _actions: SimpleQueue
    # Handlers default to None so process() can test for "not registered"
    # instead of raising AttributeError on a missing attribute.
    _message_sender: SupportsSendAndDeliverMessage = None
    _message_handler: SupportsHandleMessage = None
    _action_handler: SupportsHandleAction = None

    def __init__(self, address: bytes) -> None:
        """Create a node from its address (public key bytes)."""
        self.address = address
        self.msgs_seen = set()
        self.connections = set()
        self.data = {}
        self._vkey = VerifyKey(address)
        self._seed = None
        self._skey = None
        self._inbound = SimpleQueue()
        self._outbound = SimpleQueue()
        self._actions = SimpleQueue()
        self._message_sender = None
        self._message_handler = None
        self._action_handler = None

    @classmethod
    def from_seed(cls, seed: bytes):
        """Create a node from a seed filling out _skey."""
        skey = SigningKey(seed)
        node = cls(bytes(skey.verify_key))
        node._skey = skey
        node._seed = seed
        return node

    def __hash__(self) -> int:
        """Enable inclusion in sets."""
        return hash(self.address)

    def __lt__(self, other: Node) -> bool:
        """Order nodes by address (used to sort Connection endpoints)."""
        return self.address < other.address

    def __repr__(self) -> str:
        """Return a dict-like string with the address and, if known, the seed.

        NOTE(review): this exposes the private seed of active nodes;
        avoid logging reprs outside of demos.
        """
        if self._seed is not None:
            return "{'address': '" + format_address(self.address) + "','seed':'" + self._seed.hex() + "'}"
        return "{'address': '" + format_address(self.address) + "'}"

    def register_message_sender(self, sndr: SupportsSendAndDeliverMessage) -> None:
        """Register the message sender.

        Raises:
            TypeError: if sndr has no callable send attribute.
        """
        if not hasattr(sndr, 'send') or not callable(sndr.send):
            raise TypeError('sndr must fulfill SupportsSendAndDeliverMessage duck type')
        self._message_sender = sndr

    def register_message_handler(self, hndlr: SupportsHandleMessage) -> None:
        """Register the incoming message handler.

        Raises:
            TypeError: if hndlr has no callable handle attribute.
        """
        if not hasattr(hndlr, 'handle') or not callable(hndlr.handle):
            raise TypeError('hndlr must fulfill SupportsHandleMessage duck type')
        self._message_handler = hndlr

    def register_action_handler(self, hndlr: SupportsHandleAction) -> None:
        """Register the action handler.

        Raises:
            TypeError: if hndlr has no callable handle attribute.
        """
        if not hasattr(hndlr, 'handle') or not callable(hndlr.handle):
            raise TypeError('hndlr must fulfill SupportsHandleAction duck type')
        self._action_handler = hndlr

    def add_connection(self, connection: Connection) -> None:
        """Add the specified connection.

        Raises:
            TypeError: if connection is not a Connection.
        """
        if not isinstance(connection, Connection):
            raise TypeError('connection must be a Connection')
        self.connections.add(connection)

    def drop_connection(self, connection: Connection) -> None:
        """Drop the specified connection.

        Raises:
            TypeError: if connection is not a Connection.
            KeyError: if the connection is not present.
        """
        if not isinstance(connection, Connection):
            raise TypeError('connection must be a Connection')
        self.connections.remove(connection)

    def count_connections(self) -> int:
        """Return the number of connections held by this node."""
        return len(self.connections)

    def receive_message(self, message: Message):
        """Queue up an incoming message if its signature is valid or
        ignored.

        Messages older than MESSAGE_TTL, messages failing signature
        verification, and unsigned messages while SIGN_MESSAGES is set
        are discarded with a debug note. Accepted messages are unsealed
        in place when ENCRYPT_MESSAGES is set.

        Raises:
            TypeError: if message is not a Message.
        """
        if not isinstance(message, Message):
            raise TypeError('message must be a Message')
        if int(time()) > (message.ts + MESSAGE_TTL):
            debug("Node.receive_message: old message discarded")
            return
        if message.sig is not None:
            if not message.verify():
                debug("Node.receive_message: message signature failed verification")
                return
        elif SIGN_MESSAGES:
            debug("Node.receive_message: unsigned message rejected")
            return
        # Reached only by verified messages, or unsigned ones when
        # signatures are not required.
        if ENCRYPT_MESSAGES:
            message.unseal(self._skey)
        self._inbound.put(message)

    def send_message(self, dst: bytes, msg: bytes):
        """Queue up an outgoing message. Sign if necessary and possible.

        When the node has connections, the message is queued only if one
        of them includes dst; with no connections it is always queued.

        Raises:
            TypeError: if dst or msg is not bytes.
        """
        if type(dst) is not bytes:
            raise TypeError("dst must be bytes")
        if type(msg) is not bytes:
            raise TypeError("msg must be bytes")
        message = Message(self.address, dst, msg)
        if ENCRYPT_MESSAGES:
            message.seal()
        if self._skey is not None:
            message.sign(self._skey)
        if self.connections and not any(
            dst in (n.address for n in c.nodes) for c in self.connections
        ):
            debug("cannot deliver message due to lack of connection")
            return
        self._outbound.put(message)

    def queue_action(self, act: Action) -> None:
        """Queue an action to be processed by the action handler.

        Raises:
            TypeError: if act is not an Action.
        """
        if not isinstance(act, Action):
            raise TypeError('act must be an Action')
        self._actions.put(act)

    def process(self):
        """Process actions for this node once.

        Handles at most one item from each queue (outbound, inbound,
        actions). A queue whose handler is not registered is left
        untouched.
        """
        if self._outbound.qsize() > 0 and self._message_sender is not None:
            self._message_sender.send(self._outbound.get())
        if self._inbound.qsize() > 0 and self._message_handler is not None:
            self._message_handler.handle(self._inbound.get())
        if self._actions.qsize() > 0 and self._action_handler is not None:
            self._action_handler.handle(self._actions.get())

    def action_count(self):
        """Count the size of pending messages and actions."""
        return self._outbound.qsize() + self._inbound.qsize() + self._actions.qsize()
@dataclass
class MessageSender:
    """In-process example sender: moves messages between registered
    Nodes directly, with no network stack involved.

    Attributes:
        nodes: Nodes that can receive deliveries.
        message_queue: messages waiting for the next deliver() call.
        dead_letters: messages whose destination was not registered.
        data: free-form dict for extensibility; starts empty.
    """
    nodes: set[Node]
    message_queue: SimpleQueue
    dead_letters: list
    data: dict

    def __init__(self) -> None:
        self.data = {}
        self.dead_letters = []
        self.message_queue = SimpleQueue()
        self.nodes = set()

    def register_node(self, node: Node) -> None:
        """Make node a delivery target, then retry dead letters for it.

        Raises:
            TypeError: if node is not a Node.
        """
        if not isinstance(node, Node):
            raise TypeError('node must be a Node')
        self.nodes.add(node)
        self.check_dead_letters(node)

    def register_nodes(self, nodes: list[Node]) -> None:
        """Register every node in the given list.

        Raises:
            TypeError: if nodes is not a list.
        """
        if not isinstance(nodes, list):
            raise TypeError('nodes must be a list of Nodes')
        for node in nodes:
            self.register_node(node)

    def send(self, msg: Message) -> None:
        """Put msg on the pending queue; deliver() performs delivery.

        Raises:
            TypeError: if msg is not a Message.
        """
        if not isinstance(msg, Message):
            raise TypeError('msg must be a Message')
        debug(f"MessageSender.send(): {msg}")
        self.message_queue.put(msg)

    def check_queue(self) -> int:
        """Report how many messages are waiting for delivery."""
        return self.message_queue.qsize()

    def deliver(self) -> None:
        """Drain the pending queue.

        Each message goes to every registered node whose address equals
        its dst; a message with no such node becomes a dead letter.
        """
        while self.check_queue() > 0:
            pending = self.message_queue.get()
            recipients = [node for node in self.nodes if node.address == pending.dst]
            for node in recipients:
                node.receive_message(pending)
                debug(f'MessageSender.deliver(): delivered {format_address(sha256(pending.msg).digest())} to {format_address(node.address)}')
            if not recipients:
                self.dead_letters.append(pending)
                debug('MessageSender.deliver(): dead letter')

    def check_dead_letters(self, node: Node) -> None:
        """Hand node every dead letter addressed to it and drop those
        letters from the dead-letter list.

        Raises:
            TypeError: if node is not a Node.
        """
        if not isinstance(node, Node):
            raise TypeError('node must be a Node')
        # Deliver first, prune afterwards, so the list is only rewritten
        # once every matching letter has been handed over.
        for letter in self.dead_letters:
            if letter.dst == node.address:
                node.receive_message(letter)
        self.dead_letters[:] = [
            letter for letter in self.dead_letters if letter.dst != node.address
        ]
@dataclass
class MessageHandler:
    """Example handler for incoming messages.

    Attributes:
        nodes: Nodes this handler serves.
        data: free-form dict for extensibility; starts empty.
    """
    nodes: set[Node]
    data: dict

    def __init__(self) -> None:
        self.data = {}
        self.nodes = set()

    def register_node(self, node: Node) -> None:
        """Add node to the set of nodes served by this handler.

        Raises:
            TypeError: if node is not a Node.
        """
        if not isinstance(node, Node):
            raise TypeError('node must be a Node')
        self.nodes.add(node)

    def register_nodes(self, nodes: list[Node]) -> None:
        """Register every node in the given list.

        Raises:
            TypeError: if nodes is not a list.
        """
        if not isinstance(nodes, list):
            raise TypeError('nodes must be a list of Nodes')
        for node in nodes:
            self.register_node(node)

    def handle(self, msg: Message) -> None:
        """Turn an incoming message into a store_and_forward Action on
        the registered node it is addressed to; drop it otherwise.

        Raises:
            TypeError: if msg is not a Message.
        """
        if not isinstance(msg, Message):
            raise TypeError('msg must be a Message')
        debug(f'MessageHandler.handle(): {msg}')
        recipient = next((node for node in self.nodes if node.address == msg.dst), None)
        if recipient is None:
            debug('MessageHandler.handle(): message dropped')
            return
        recipient.queue_action(Action('store_and_forward', {"msg": msg.msg}))
@dataclass
class ActionHandler:
    """Example action handler.

    Attributes:
        node: the Node whose actions this handler processes.
        other_nodes: nodes used as forwarding targets when node has no
            connections.
    """
    node: Node
    other_nodes: set[Node]

    def __init__(self, node: Node, other_nodes: list[Node]) -> None:
        self.node = node
        self.other_nodes = set(other_nodes)

    def handle(self, act: Action) -> None:
        """Handle an action. Limited to store_and_forward action.

        A message not seen before is recorded in node.msgs_seen (by its
        SHA-256 digest) and forwarded: to the far end of every
        connection when the node has connections, otherwise to up to
        two distinct nodes picked at random from other_nodes. Actions
        with any other name are ignored.
        """
        if act.name != 'store_and_forward':
            return
        payload = act.data['msg']
        digest = sha256(payload).digest()
        if digest in self.node.msgs_seen:
            debug("ActionHandler.handle(): store_and_forward skipped for seen message")
            return
        # store
        self.node.msgs_seen.add(digest)
        debug(f"ActionHandler.handle(): store_and_forward [{payload.hex()}]")
        # forward
        if self.node.count_connections() > 0:
            # forward to all connected nodes
            for c in self.node.connections:
                peer = next(n for n in c.nodes if n is not self.node)
                self.node.send_message(peer.address, payload)
            return
        # sample() needs a sequence, and picks distinct peers so the same
        # node is not sent the message twice.
        candidates = list(self.other_nodes)
        if not candidates:
            debug("ActionHandler.handle(): no other nodes to forward to")
            return
        for peer in sample(candidates, min(2, len(candidates))):
            self.node.send_message(peer.address, payload)
def run_tick(nodes: list[Node], msg_sender: MessageSender):
    """Advance the simulation by one tick.

    Every node gets one process() call, in list order; afterwards the
    sender delivers whatever is pending.
    """
    for node in nodes:
        node.process()

    msg_sender.deliver()
def action_count(nodes: list[Node]):
    """Total the pending actions and messages across all given nodes.

    Raises:
        TypeError: if nodes is not a list.
    """
    if not isinstance(nodes, list):
        raise TypeError('nodes must be list of Nodes')
    return sum(node.action_count() for node in nodes)
def main():
    """Run an interactive demo: 16 in-process nodes driven by a small
    command prompt (enter h for the command list).

    The loop ends on the quit command or when stdin is closed.
    """
    # create some nodes
    nodes = [Node.from_seed(token_bytes(32)) for _ in range(16)]

    # create handlers
    msg_handler = MessageHandler()
    msg_handler.register_nodes(nodes)
    msg_sender = MessageSender()
    msg_sender.register_nodes(nodes)

    # register handlers
    for n in nodes:
        n.register_message_handler(msg_handler)
        n.register_message_sender(msg_sender)
        n.register_action_handler(ActionHandler(n, [on for on in nodes if on is not n]))

    while True:
        try:
            data = input("$: ")
        except EOFError:
            # stdin closed (Ctrl-D or end of piped input): exit as if the
            # user had entered quit instead of crashing with a traceback.
            print()
            break

        # first space-separated word is the command, the rest its argument
        command, _, data = data.partition(' ')
        command = command.strip()
        data = data.strip()

        if command in ('quit', 'q'):
            break
        elif command in ('list', 'nodes', 'l', 'n', 'ln'):
            for n in nodes:
                print(f"{format_address(n.address)}: {[format_address(m) for m in n.msgs_seen]}")
        elif command in ('listcon', 'connections', 'lc'):
            connections = set()
            for n in nodes:
                connections = connections.union(n.connections)
            for c in connections:
                cnodes = list(c.nodes)
                print(f"{format_address(cnodes[0].address)} - {format_address(cnodes[1].address)}")
        elif command in ('c', 'connect'):
            # give every node 3 randomly chosen links (repeats possible)
            for n in nodes:
                others = [o for o in nodes if o is not n]
                for _ in range(3):
                    o = others[randint(0, len(others)-1)]
                    n.add_connection(Connection([n, o]))
                    o.add_connection(Connection([n, o]))
        elif command in ('message', 'm'):
            # inject the message at a random node, addressed to itself;
            # the handlers spread it from there
            src = nodes[randint(0, len(nodes)-1)]
            message = Message(src.address, src.address, bytes(data, 'utf-8'))
            if ENCRYPT_MESSAGES:
                message.seal()
            if SIGN_MESSAGES:
                message.sign(src._skey)
            src.receive_message(message)
        elif command in ('d', 'debug'):
            print("debug enabled" if toggle_debug() else "debug disabled")
        elif command in ('s', 'short'):
            print("short addresses enabled" if toggle_short_address() else "short addresses disabled")
        elif command in ('r', 'run'):
            while action_count(nodes) > 0:
                run_tick(nodes, msg_sender)
        elif command in ('h', 'help', '?'):
            print("options:\t[l|ln|nodes|list] to list nodes and messages seen by each")
            print("\t\t[m|message] {str} to send a message")
            print("\t\t[c|connect] to connect nodes together randomly")
            print("\t\t[lc|listcon|connections] list all connections")
            print("\t\t[q|quit] to end")
            print("\t\t[h|help|?] display this text")
            print("\t\t[d|debug] to toggle debug messages")
            print("\t\t[s|short] to toggle displaying short address format")
            print("\t\t[r|run] to run until no pending actions remain")
            print("\t\tanything else to process a tick")
        else:
            run_tick(nodes, msg_sender)
# Start the interactive demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Add Comment
Please, Sign In to add comment