# Pastebin paste "Untitled" by dan-masek, Oct 31st, 2019 (Python, 3.11 KB).
  1. class Node(object):
  2.     def __init__(self, parent, symbol, weight):
  3.         self.parent = parent
  4.         self.symbol = symbol
  5.         self.weight = weight
  6.         self.id = None
  7.         self.nodes = {}
  8.        
  9.     def __str__(self):
  10.         if self.is_root_node:
  11.             return "Root"
  12.        
  13.         label = "Leaf" if self.is_leaf_node else "Node"
  14.        
  15.         return "%s(input='%s',P=%0.4f,id=%s)" % (label, self.matching_string, self.cumulative_weight, self.id)
  16.        
  17.     def dump(self, depth = 0):
  18.         print "%s %s" % ("*" * (depth + 1), self)
  19.         for _,node in sorted(self.nodes.iteritems()):
  20.             node.dump(depth + 1)
  21.        
  22.     def add_children(self, symbols_with_weights):
  23.         for symbol, weight in symbols_with_weights:
  24.             self.nodes[symbol] = Node(self, symbol, weight)
  25.            
  26.     @property
  27.     def is_root_node(self):
  28.         return self.parent is None
  29.        
  30.     @property
  31.     def is_leaf_node(self):
  32.         return not self.nodes
  33.        
  34.     @property
  35.     def leaf_count(self):
  36.         if self.is_leaf_node:
  37.             return 1
  38.         else:
  39.             return sum((node.leaf_count for node in self.nodes.itervalues()))
  40.            
  41.     @property
  42.     def cumulative_weight(self):
  43.         if self.is_root_node:
  44.             return self.weight
  45.         return self.parent.cumulative_weight * self.weight
  46.        
  47.     @property
  48.     def matching_string(self):
  49.         if self.is_root_node:
  50.             return ""
  51.         return self.parent.matching_string + self.symbol
  52.  
  53.  
  54. class RootNode(Node):
  55.     def __init__(self):
  56.         super(RootNode, self).__init__(None, None, 1.0)
  57.  
  58.        
  59. def node_list(tree):
  60.     result = [tree]
  61.     for node in tree.nodes.itervalues():
  62.         result += node_list(node)
  63.     return result
  64.  
  65.    
  66. def highest_probability_leaf(tree):
  67.     if tree.is_leaf_node:
  68.         return tree
  69.        
  70.     candidates = [highest_probability_leaf(node) for node in tree.nodes.itervalues()]
  71.     sorted_candidates = sorted(candidates, key=lambda v:v.cumulative_weight)
  72.     return sorted_candidates[-1]
  73.    
  74.  
  75. def build_tree(input_symbols, alphabet_size):
  76.     tree = RootNode()    
  77.    
  78.     while (tree.leaf_count < alphabet_size - 1):
  79.         print "Leaf count:", tree.leaf_count
  80.         print "Candidates:"
  81.         for node in node_list(tree):
  82.             if node.is_leaf_node:
  83.                 print str(node)
  84.      
  85.         new_node = highest_probability_leaf(tree)
  86.         print "Choice: ", new_node, "\n"
  87.        
  88.         new_node.add_children(input_symbols)
  89.        
  90.     node_id = 0
  91.     leaf_id = 0
  92.     for node in node_list(tree):
  93.         if node.is_leaf_node:
  94.             node.id = leaf_id
  95.             leaf_id += 1
  96.         elif not node.is_root_node:
  97.             node.id = node_id
  98.             node_id += 1            
  99.  
  100.     return tree
  101.    
  102.    
  103. def make_symbols(prob_of_0):
  104.     return [('0', prob_of_0), ('1', 1 - prob_of_0)]
  105.  
  106.  
  107. input_symbols = make_symbols(0.8)
  108.  
  109. tree = build_tree(input_symbols, 36)
  110.  
  111. print "\nTree:"
  112. tree.dump()
  113.  
  114. print "\nNodes:"
  115. for node in node_list(tree):
  116.     print str(node)
# End of paste.