Advertisement
Guest User

Untitled

a guest
Mar 23rd, 2017
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.81 KB | None | 0 0
  1. import numpy as np
  2. import pandas as pd
  3. from PIL import Image, ImageDraw
  4.  
  5.  
  6. # Function for reading data from csv file
  def read_data(path):
      """Load the CSV file at ``path`` and return it as a pandas DataFrame."""
      return pd.read_csv(path)
  10.  
  11.  
  12. # Function for calculating entropy
  def entropy(data):
      """Return the Shannon entropy of the values in ``data``, in nats.

      Args:
          data: array-like of hashable/sortable values (list, numpy array,
              pandas Series, ...). ``np.unique`` flattens its input, so nested
              or multi-dimensional data is treated as one flat collection.

      Returns:
          ``-sum(p * log(p))`` over the relative frequency ``p`` of each
          distinct value, using the natural logarithm. Empty input yields 0.0.
      """
      _, counts = np.unique(data, return_counts=True)

      # Normalise by the number of values that were actually counted rather
      # than len(data): the two differ for nested/2-D input, which would make
      # the "probabilities" not sum to one.
      total = counts.sum()
      if total == 0:
          return 0.0

      # Explicit float divisor keeps this a true division even where integer
      # array division would floor.
      probabilities = counts / float(total)
      return -probabilities.dot(np.log(probabilities))
  17.  
  18.  
  19. # Function for separating data by predicate
  def separate(data, predicate):
      """Split ``data`` into the rows that satisfy ``predicate`` and the rest.

      Args:
          data: pandas DataFrame to partition.
          predicate: sequence ``[attribute, value, type_code]``. With
              ``type_code == 0`` the test is ``data[attribute] <= value``;
              with any other code it is ``data[attribute] == value``.

      Returns:
          Tuple ``(data_true, data_false)``. Together the two frames contain
          every row of ``data`` exactly once.
      """
      pred_attr, pred_value, pred_type = predicate[0], predicate[1], predicate[2]
      column = data[pred_attr]

      if pred_type == 0:
          mask = column <= pred_value
      else:
          mask = column == pred_value

      # The false side is the complement of the mask rather than a second
      # comparison, so rows whose value compares False both ways (NaN) are not
      # silently dropped. They land on the false side, which is also the
      # branch DecisionTree.predict follows for such a value.
      return data[mask], data[~mask]
  32.  
  33.  
  34. # Score function
  def score(data, predicate):
      """Return the information gain of splitting ``data`` by ``predicate``.

      The gain is the entropy of the ``'type'`` column minus the size-weighted
      entropy of the same column in the two partitions produced by
      ``separate``.

      Args:
          data: pandas DataFrame with a ``'type'`` column holding class labels.
          predicate: ``[attribute, value, type_code]`` as understood by
              ``separate``. A two-element ``[attribute, value]`` predicate is
              accepted and treated as a numeric threshold (type code 0).

      Returns:
          The information gain as a float; 0.0 when ``data`` has no rows.
      """
      attr, value = predicate[0], predicate[1]
      attr_type = predicate[2] if len(predicate) > 2 else 0

      total_count = len(data)
      if total_count == 0:
          # Nothing to split; also avoids dividing by zero below.
          return 0.0

      # Score the very partition that separate() will apply, so string
      # attributes are evaluated with '==' instead of a lexicographic '<='.
      data_true, data_false = separate(data, [attr, value, attr_type])

      weighted_entropy = (len(data_true) * entropy(data_true['type']) +
                          len(data_false) * entropy(data_false['type']))
      return entropy(data['type']) - float(weighted_entropy) / total_count
  48.  
  49.  
  50. # Function to get attribute type (0 for float and 1 for string)
  def attribute_type(attr):
      """Classify an attribute value as numeric or non-numeric.

      Args:
          attr: a single attribute value.

      Returns:
          ``(float(attr), 0)`` when the value can be converted to a float,
          otherwise ``(attr, 1)`` with the value unchanged.
      """
      try:
          return float(attr), 0
      except (TypeError, ValueError):
          # ValueError: non-numeric string; TypeError: objects such as None
          # that float() rejects outright. Both mean "not numeric".
          return attr, 1
  59.  
  60.  
  61. # Main decision tree class
  class DecisionTree:
      """Binary decision tree built greedily by information gain (ID3 style).

      ``build`` returns the root of a tree made of ``Node`` and ``Leaf``
      objects (both subclasses of this class); ``predict`` walks such a tree.
      """

      def build(self, x, score_func):
          """Recursively build a tree for the labelled rows in ``x``.

          Args:
              x: pandas DataFrame whose ``'type'`` column holds the class
                  label; every other column is used as an attribute.
              score_func: callable ``score_func(data, [attr, value, type_code])``
                  returning the gain of that split (higher is better).

          Returns:
              A ``Leaf`` when the rows share one class or no useful split
              exists, otherwise a ``Node`` with recursively built subtrees.

          Raises:
              ValueError: if ``x`` has no rows.
          """
          y = np.array(x['type'])
          if y.size == 0:
              raise ValueError("cannot build a decision tree from an empty data set")

          # All rows in one class: nothing left to decide.
          if np.unique(y).shape[0] == 1:
              return Leaf(y[0])

          # Exclude the label column by name instead of assuming it is last.
          attributes = [column for column in x.columns if column != 'type']

          max_info_gain = 0
          most_inf_predicate = None

          for attr in attributes:
              attr_values = np.unique(x[attr])
              # The type code is decided once per attribute from its first
              # distinct value.
              _, attr_type = attribute_type(attr_values[0])

              for value in attr_values:
                  info_gain = score_func(x, [attr, value, attr_type])
                  if info_gain > max_info_gain:
                      max_info_gain = info_gain
                      most_inf_predicate = [attr, value, attr_type]

          # No split improves on the current node (e.g. identical attribute
          # values with differing labels): fall back to the majority class.
          if most_inf_predicate is None:
              return Leaf(self._majority_class(y))

          rows_true, rows_false = separate(x, most_inf_predicate)

          # A split that leaves one side empty cannot make progress.
          if rows_true.empty or rows_false.empty:
              return Leaf(self._majority_class(y))

          false_branch = self.build(rows_false, score_func)
          true_branch = self.build(rows_true, score_func)
          return Node(
              most_inf_predicate[0], most_inf_predicate[1], most_inf_predicate[2],
              false_branch, true_branch)

      @staticmethod
      def _majority_class(labels):
          """Return the most frequent label; ties go to the first in sorted order."""
          class_names, counts = np.unique(labels, return_counts=True)
          return class_names[np.argmax(counts)]

      def predict(self, x):
          """Return the class predicted for the first row of ``x``.

          Args:
              x: mapping from attribute name to value(s), e.g. a one-row
                  DataFrame, a Series or a dict. Only the first value of each
                  attribute is consulted.

          Returns:
              The class name stored in the reached leaf. Called on an object
              that is neither ``Node`` nor ``Leaf`` (a bare ``DecisionTree``),
              it returns ``1``.
          """
          if isinstance(self, Node):
              attr, value, attr_type = self.predicate
              # atleast_1d lets a scalar (single row) be indexed like a column.
              sample_value = np.atleast_1d(x[attr])[0]

              if attr_type == 0:
                  satisfied = sample_value <= value
              else:
                  satisfied = sample_value == value

              branch = self.true_branch if satisfied else self.false_branch
              return branch.predict(x)

          if isinstance(self, Leaf):
              return self.class_name

          # Fallback kept for objects that carry no tree structure.
          return 1
  131.  
  132.  
  133. # Class for decision tree node
  class Node(DecisionTree):
      """Inner tree node: a split predicate plus the subtree for each outcome."""

      def __init__(self, predicate_1, predicate_2, predicate_3, false_branch, true_branch):
          # The three predicate parts are kept together as one list.
          self.predicate = [
              predicate_1,
              predicate_2,
              predicate_3,
          ]
          # Subtree followed when the predicate does not hold / holds.
          self.false_branch, self.true_branch = false_branch, true_branch
  139.  
  140.  
  class Leaf(DecisionTree):
      """Terminal tree node that stores the class name it stands for."""

      # Class-level default; each instance overrides it in __init__.
      class_name = None

      def __init__(self, class_name):
          self.class_name = class_name
  146.  
  147.  
  def getdepth(tree):
      """Return the number of levels in ``tree``; anything but a Node counts as 1."""
      if not isinstance(tree, Node):
          return 1
      deepest_child = max(getdepth(tree.false_branch), getdepth(tree.true_branch))
      return deepest_child + 1
  153.  
  154.  
  def getwidth(tree):
      """Return the width of ``tree``: 1 for a non-Node, else the sum over both subtrees."""
      if not isinstance(tree, Node):
          return 1
      return sum(getwidth(child) for child in (tree.false_branch, tree.true_branch))
  160.  
  161.  
  def drawtree(tree, path='tree.jpg'):
      """Render ``tree`` on a white canvas and save it to ``path`` as a JPEG.

      The canvas is 100 pixels wide per unit of ``getwidth(tree)`` and 100
      pixels high per unit of ``getdepth(tree)``; the root is drawn at the
      horizontal centre, 20 pixels from the top.
      """
      canvas_width = getwidth(tree) * 100
      canvas_height = getdepth(tree) * 100

      white = (255, 255, 255)
      img = Image.new('RGB', (canvas_width, canvas_height), white)

      drawnode(ImageDraw.Draw(img), tree, canvas_width / 2, 20)
      img.save(path, 'JPEG')
  171.  
  172.  
  def drawnode(draw, tree, x, y):
      """Draw ``tree`` with its root at ``(x, y)`` and recurse into its subtrees.

      Args:
          draw: drawing context offering ``text`` and ``line`` (the file passes
              a PIL ``ImageDraw.Draw`` object).
          tree: ``Node`` or ``Leaf`` to draw; any other object draws nothing.
          x: horizontal position of this node.
          y: vertical position of this node.

      A Node is labelled ``attribute<=value`` (type code 0) or
      ``attribute==value`` and connected by red lines to its children, which
      are placed 100 pixels lower. A Leaf is labelled with its class name.
      """
      if isinstance(tree, Node):
          shift = 100
          false_width = getwidth(tree.false_branch) * shift
          true_width = getwidth(tree.true_branch) * shift

          # Horizontal extent occupied by this subtree, centred on x.
          left = x - (false_width + true_width) / 2
          right = x + (false_width + true_width) / 2

          # Each child is centred within its own share of that extent.
          false_x = left + false_width / 2
          true_x = right - true_width / 2
          child_y = y + shift

          operator = "<=" if tree.predicate[2] == 0 else "=="
          predicate = str(tree.predicate[0]) + operator + str(tree.predicate[1])

          draw.text((x - 20, y - 10), predicate, (0, 0, 0))
          draw.line((x, y, false_x, child_y), fill=(255, 0, 0))
          draw.line((x, y, true_x, child_y), fill=(255, 0, 0))
          drawnode(draw, tree.false_branch, false_x, child_y)
          drawnode(draw, tree.true_branch, true_x, child_y)
      elif isinstance(tree, Leaf):
          # str() mirrors the predicate label above so that non-string class
          # labels (e.g. integers) can be drawn as well.
          draw.text((x - 20, y), str(tree.class_name), (0, 0, 0))
  194.  
  195.  
  def main(path="halloween.csv"):
      """Build a decision tree from the CSV file at ``path`` and return its root.

      Args:
          path: CSV file to read; defaults to ``"halloween.csv"``.

      Returns:
          The tree produced by ``DecisionTree().build`` with ``score`` as the
          split-scoring function. Pass it to ``drawtree`` to render it.
      """
      x = read_data(path)
      return DecisionTree().build(x, score)
  208.  
  209.  
  # Run main() only when this file is executed as a script, not when imported.
  if __name__ == "__main__":
      main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement