Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class DecisionTree:
    """Binary decision tree classifier that splits on weighted entropy.

    Each internal node stores a ``[feature, threshold]`` rule. A sample is
    routed to the right child when its feature value equals the threshold
    (string thresholds) or is strictly greater than it (any other
    threshold); otherwise it goes left.

    Attributes:
        root: Root node of the trained tree, or None before ``train``.
        max_depth: Maximum depth at which a node may still be split.
        depth: Depth of the deepest node created by the last ``train`` call.
    """

    def __init__(self, max_depth=float("inf")):
        self.root = None
        self.max_depth = max_depth
        self.depth = 0

    @staticmethod
    def _goes_right(value, threshold):
        """Return True when ``value`` belongs to the right branch."""
        # isinstance (not type() ==) so that numpy string scalars, which
        # subclass str, are also treated as categorical thresholds.
        if isinstance(threshold, str):
            return value == threshold
        return value > threshold

    @staticmethod
    def _entropy(label_hist):
        """Return the base-2 entropy of a ``{label: count}`` histogram."""
        total = sum(label_hist.values())
        # Zero counts contribute nothing; skipping them avoids log2(0).
        return -sum(
            count / total * np.log2(count / total)
            for count in label_hist.values()
            if count
        )

    @staticmethod
    def _majority_label(labels):
        """Return the most frequent label; ties go to the first one seen."""
        counts = {}
        for label in labels:
            counts[label] = counts.get(label, 0) + 1
        return max(counts, key=counts.get)

    def histograms(self, column, labels, threshold):
        """Count labels on each side of a candidate split.

        Args:
            column: Feature values, one per sample.
            labels: Labels aligned with ``column``.
            threshold: Candidate split value.

        Returns:
            A ``(left_hist, right_hist)`` pair of ``{label: count}`` dicts.
        """
        left_hist, right_hist = {}, {}
        for value, label in zip(column, labels):
            hist = right_hist if self._goes_right(value, threshold) else left_hist
            hist[label] = hist.get(label, 0) + 1
        return left_hist, right_hist

    def split_data(self, data, labels, split_rule):
        """Partition samples and labels according to ``split_rule``.

        Args:
            data: Sequence of samples, each indexable by feature.
            labels: Labels aligned with ``data``.
            split_rule: ``(feature, threshold)`` pair.

        Returns:
            ``(left_data, left_labels, right_data, right_labels)`` as
            numpy arrays.
        """
        feature, threshold = split_rule
        left_data, left_labels = [], []
        right_data, right_labels = [], []
        for datum, label in zip(data, labels):
            if self._goes_right(datum[feature], threshold):
                right_data.append(datum)
                right_labels.append(label)
            else:
                left_data.append(datum)
                left_labels.append(label)
        return (
            np.array(left_data),
            np.array(left_labels),
            np.array(right_data),
            np.array(right_labels),
        )

    def impurity(self, left_label_hist, right_label_hist):
        """Return the size-weighted average entropy of the two histograms.

        Returns 0.0 when both histograms are empty instead of dividing
        by zero.
        """
        left_total = sum(left_label_hist.values())
        right_total = sum(right_label_hist.values())
        total = left_total + right_total
        if total == 0:
            return 0.0
        weighted = (
            left_total * self._entropy(left_label_hist)
            + right_total * self._entropy(right_label_hist)
        )
        return weighted / total

    def _best_split(self, data, labels):
        """Return the ``[feature, threshold]`` with lowest impurity, or None.

        Only splits that put at least one sample on each side are
        considered, so every accepted split makes both children strictly
        smaller than the parent.
        """
        if len(data) == 0:
            return None
        # Start at infinity: entropy exceeds 1 bit with more than two classes.
        best_impurity = float("inf")
        best = None
        for feature in range(len(data[0])):
            column = data[:, feature]
            for value in np.unique(column):
                left_hist, right_hist = self.histograms(column, labels, value)
                if not left_hist or not right_hist:
                    continue
                score = self.impurity(left_hist, right_hist)
                if score < best_impurity:
                    best_impurity, best = score, [feature, value]
        return best

    def segmenter(self, data, labels):
        """Choose the split rule with the lowest weighted entropy.

        Args:
            data: 2-D numpy array, one row per sample.
            labels: Labels aligned with the rows of ``data``.

        Returns:
            ``[feature, threshold]``. When no split separates the samples
            into two non-empty groups, ``[0, float("inf")]`` is returned.
        """
        best = self._best_split(data, labels)
        if best is None:
            return [0, float("inf")]
        return best

    def _grow(self, data, labels, depth):
        """Recursively build the subtree for the samples at ``depth``."""
        self.depth = max(self.depth, depth)
        if len(np.unique(labels)) == 1:
            return Leaf(labels[0])
        if depth >= self.max_depth:
            return Leaf(self._majority_label(labels))
        split = self._best_split(data, labels)
        if split is None:
            # Identical feature values with differing labels: cannot separate.
            return Leaf(self._majority_label(labels))
        left_data, left_labels, right_data, right_labels = self.split_data(
            data, labels, split
        )
        return Inner(
            split,
            self._grow(left_data, left_labels, depth + 1),
            self._grow(right_data, right_labels, depth + 1),
        )

    def train(self, data, labels):
        """Fit the tree and store it in ``self.root``.

        Args:
            data: 2-D numpy array, one row per sample.
            labels: Labels aligned with the rows of ``data``.

        Returns:
            The root node (a ``Leaf`` when the data is pure, cannot be
            split, or ``max_depth`` is 0; otherwise an ``Inner``).

        Raises:
            ValueError: If ``labels`` is empty.
        """
        if len(labels) == 0:
            raise ValueError("cannot train a DecisionTree on empty data")
        self.depth = 0
        # Assign the root once, after the whole tree is built, so that a
        # single-leaf tree is also reachable from predict().
        self.root = self._grow(data, labels, 0)
        return self.root

    def predict(self, data):
        """Return the label predicted for a single sample.

        Args:
            data: One sample, indexable by feature.

        Raises:
            AttributeError: If the tree has not been trained.
        """
        if self.root is None:
            # AttributeError is what dereferencing the missing root raised
            # before; keep the type, but say what actually went wrong.
            raise AttributeError("DecisionTree.predict called before train")
        node = self.root
        while not node.isLeaf():
            feature, threshold = node.split_rule
            if self._goes_right(data[feature], threshold):
                node = node.right
            else:
                node = node.left
        return node.label
class Node:
    """Common base type for the nodes of a decision tree."""

    def isLeaf(self):
        """Placeholder overridden by subclasses; the base returns None."""
        return None
class Inner(Node):
    """Internal tree node: a split rule plus left and right children."""

    def __init__(self, split_rule, left, right):
        """Store the split rule and the two child nodes."""
        self.split_rule, self.left, self.right = split_rule, left, right

    def isLeaf(self):
        """Internal nodes are never leaves."""
        return False

    def __str__(self):
        """Render the node as the string form of its split rule."""
        return f"{self.split_rule!s}"
class Leaf(Node):
    """Terminal tree node holding a single label."""

    def __init__(self, label):
        """Store the label carried by this leaf."""
        self.label = label

    def isLeaf(self):
        """Leaves always report True."""
        return True

    def __str__(self):
        """Render the node as the string form of its label."""
        return f"{self.label!s}"
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement