Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # file name: prim.py
- from typing import Dict, List
- # Class for a Directed Maximum weighted tree
class MaxTree:
    """Directed maximum-weight spanning tree built with Prim's algorithm.

    The tree is grown from the first node of ``nodes``, which becomes the
    root. Every other node is attached to the already-placed node offering
    the heaviest connecting edge.

    Attributes:
        q: Working list of nodes not yet placed in the tree; empty once
            construction has finished.
        edges_cost: The ``costs`` mapping passed to the constructor.
        tree_nodes: Nodes in the order they were added to the tree.
        tree_edges: ``(child, parent)`` tuples, one per non-root node, in
            the order the children were added.
        children_of: Maps every tree node to the list of its children
            (in the order they were added).
        tree_root: First node added to the tree, or ``None`` when ``nodes``
            is empty.
    """

    # Cost assigned to a node that has no link to the tree yet. It must be
    # lower than any real weight so that negative weights still get linked.
    _UNLINKED = float("-inf")

    def __init__(self, nodes: List, costs: Dict):
        """Build the tree immediately from ``nodes`` and ``costs``.

        Args:
            nodes: Hashable nodes; the first one becomes the root. The
                argument itself is not modified.
            costs: Mapping such as ``{(a, b): n}`` where ``n`` is the weight
                between nodes ``a`` and ``b``. A pair may be stored in
                either orientation.

        Raises:
            KeyError: If a pair of nodes has no weight in ``costs`` in
                either orientation.
        """
        self.q = list(nodes)  # private copy, consumed during construction
        self.edges_cost = costs
        self.tree_nodes = []
        self.tree_edges = []
        self.children_of = {}
        self.tree_root = None
        self._parents = {}  # child -> parent, backs parent_of()
        self.__build_tree()

    def _edge_cost(self, a, b):
        """Return the weight between ``a`` and ``b``.

        ``(a, b)`` is tried first, then ``(b, a)``, so a cost mapping that
        stores each undirected pair only once is accepted.

        Raises:
            KeyError: If neither orientation is present.
        """
        try:
            return self.edges_cost[(a, b)]
        except KeyError:
            try:
                return self.edges_cost[(b, a)]
            except KeyError:
                # Report the pair as it was asked for, not the fallback.
                raise KeyError((a, b)) from None

    def __build_tree(self):
        """Run Prim's algorithm and record the tree as parent/child links."""
        best_cost = {node: self._UNLINKED for node in self.q}
        best_link = {node: None for node in self.q}

        while self.q:
            # On ties max() returns the first candidate, so the very first
            # pick (all costs equal) is the first node of the input.
            node = max(self.q, key=best_cost.__getitem__)
            self.q.remove(node)
            self.tree_nodes.append(node)
            self.children_of[node] = []

            parent = best_link[node]
            if parent is not None:
                # The parent was placed earlier, so its children list exists.
                self.tree_edges.append((node, parent))
                self.children_of[parent].append(node)
                self._parents[node] = parent

            # Offer the freshly placed node as a better link to the rest.
            for other in self.q:
                cost = self._edge_cost(other, node)
                if cost > best_cost[other]:
                    best_cost[other] = cost
                    best_link[other] = node

        self.tree_root = self.tree_nodes[0] if self.tree_nodes else None

    def parent_of(self, child):
        """Return the parent of ``child``.

        Returns ``None`` for the root and for nodes that are not in the
        tree.
        """
        return self._parents.get(child)

    def attr_parent_of(self, attr, obj):
        """Return the object in ``obj`` that is the tree parent of ``attr``.

        Args:
            attr: Object whose ``field_name`` is looked up with
                :meth:`parent_of`.
            obj: Iterable of objects exposing ``field_name``.

        Returns:
            The last element of ``obj`` whose ``field_name`` equals the
            parent's name, or ``None`` when no element matches.
        """
        parent_field_name = self.parent_of(attr.field_name)
        parent = None
        for other in obj:
            # No early exit: with duplicate field names the last match wins.
            if other.field_name == parent_field_name:
                parent = other
        return parent
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement