Advertisement
Guest User

Untitled

a guest
Mar 25th, 2017
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.20 KB | None | 0 0
  1. class DecisionTree():
  2.    
  3.     def __init__(self, max_depth=float("inf")):
  4.         self.root = None
  5.         self.max_depth = max_depth
  6.         self.depth = 0
  7.        
  8.     def histograms(self, column, labels, threshold):
  9.         left_hist, right_hist = {}, {}
  10.         for f in range(len(column)):
  11.             l = labels[f]
  12.             if type(threshold) == str:
  13.                 if column[f] == threshold:
  14.                     if l in right_hist.keys():
  15.                         right_hist[l] += 1
  16.                     else:
  17.                         right_hist[l] = 1
  18.                 else:
  19.                     if l in left_hist.keys():
  20.                         left_hist[l] += 1
  21.                     else:
  22.                         left_hist[l] = 1
  23.             else:
  24.                 if column[f] > threshold:
  25.                     if l in right_hist.keys():
  26.                         right_hist[l] += 1
  27.                     else:
  28.                         right_hist[l] = 1
  29.                 else:
  30.                     if l in left_hist.keys():
  31.                         left_hist[l] += 1
  32.                     else:
  33.                         left_hist[l] = 1
  34.         return left_hist, right_hist
  35.    
  36.     def split_data(self, data, labels, split_rule):
  37.         feature, threshold = split_rule
  38.         left_data, left_labels = [], []
  39.         right_data, right_labels = [], []
  40.         for d in range(len(data)):
  41.             datum = data[d]
  42.             label = labels[d]
  43.             if type(threshold) == str:
  44.                 if datum[feature] == threshold:
  45.                     right_data.append(datum)
  46.                     right_labels.append(label)
  47.                 else:
  48.                     left_data.append(datum)
  49.                     left_labels.append(label)
  50.             else:
  51.                 if datum[feature] > threshold:
  52.                     right_data.append(datum)
  53.                     right_labels.append(label)
  54.                 else:
  55.                     left_data.append(datum)
  56.                     left_labels.append(label)
  57.         return np.array(left_data), np.array(left_labels), np.array(right_data), np.array(right_labels)
  58.        
  59.     def impurity(self, left_label_hist, right_label_hist):
  60.         left_total = sum(left_label_hist.values())
  61.         left_entropy = -sum(map(lambda x: x / left_total * np.log2(x / left_total), left_label_hist.values()))
  62.         right_total = sum(right_label_hist.values())
  63.         right_entropy = -sum(map(lambda x: x / right_total * np.log2(x / right_total), right_label_hist.values()))
  64.         return (left_total * left_entropy + right_total * right_entropy) / (left_total + right_total)
  65.        
  66.     def segmenter(self, data, labels):
  67.         min_val = 1
  68.         min_feat = 0
  69.         threshold = float("inf")
  70.         for c in range(len(data[0])):
  71.             col = data[:,c]
  72.             unique = np.unique(col)
  73.             for i in range(len(unique)):
  74.                 val = unique[i]
  75.                 left_hist, right_hist = self.histograms(col, labels, val)
  76.                 H = self.impurity(left_hist, right_hist)
  77.                 if H < min_val:
  78.                     min_val, min_feat, threshold = H, c, val
  79.         return [min_feat, threshold]
  80.        
  81.     def train(self, data, labels):
  82.         if len(np.unique(labels)) == 1:
  83.             return Leaf(labels[0])
  84.         left, right = None, None
  85.         split = self.segmenter(data, labels)
  86.         feature, threshold = split
  87.         left_hist, right_hist = self.histograms(data[:,feature], labels, threshold)
  88.         if len(right_hist) == 0 or self.depth == self.max_depth:
  89.             return Leaf(max(left_hist, key=lambda k: left_hist[k]))
  90. #         print(split, left_hist, right_hist)
  91.         left_data, left_labels, right_data, right_labels = self.split_data(data, labels, split)
  92.         n = Inner(split, self.train(left_data, left_labels), self.train(right_data, right_labels))
  93.         self.root = n
  94.         return n
  95.        
  96.     def predict(self, data):
  97.         node = self.root
  98.         while not node.isLeaf():
  99.             feature, threshold = node.split_rule
  100. #             print(feature, threshold, data[feature])
  101.             if type(threshold) == str:
  102.                 if data[feature] == threshold:
  103. #                     print("RIGHT")
  104.                     node = node.right
  105.                 else:
  106. #                     print("LEFT")
  107.                     node = node.left
  108.             else:
  109.                 if data[feature] > threshold:
  110. #                     print("RIGHT")
  111.                     node = node.right
  112.                 else:
  113. #                     print("LEFT")
  114.                     node = node.left
  115. #         print(node)
  116.         return node.label
  117.    
  118. class Node():
  119.    
  120.     def isLeaf(self):
  121.         pass
  122.    
  123. class Inner(Node):
  124.    
  125.     def __init__(self, split_rule, left, right):
  126.         self.split_rule = split_rule
  127.         self.left = left
  128.         self.right = right
  129.    
  130.     def isLeaf(self):
  131.         return False
  132.        
  133.     def __str__(self):
  134.         return str(self.split_rule)
  135.        
  136. class Leaf(Node):  
  137.    
  138.     def __init__(self, label):
  139.         self.label = label
  140.        
  141.     def isLeaf(self):
  142.         return True
  143.        
  144.     def __str__(self):
  145.         return str(self.label)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement