Advertisement
Guest User

boropeder

a guest
May 24th, 2019
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 15.17 KB | None | 0 0
  1. from math import log
  2.  
  3.  
  4. def unique_counts(rows):
  5. """Create counts of possible results (the last column of
  6. each row is the result)
  7.  
  8. :param rows: dataset
  9. :type rows: list
  10. :return: dictionary of possible classes as keys and count
  11. as values
  12. :rtype: dict
  13. """
  14. results = {}
  15. for row in rows:
  16. # The result is the last column
  17. r = row[len(row) - 1]
  18. if r not in results:
  19. results[r] = 0
  20. results[r] += 1
  21. return results
  22.  
  23.  
  24. def gini_impurity(rows):
  25. """Probability that a randomly placed item will
  26. be in the wrong category
  27.  
  28. :param rows: dataset
  29. :type rows: list
  30. :return: Gini impurity
  31. :rtype: float
  32. """
  33. total = len(rows)
  34. counts = unique_counts(rows)
  35. imp = 0
  36. for k1 in counts:
  37. p1 = float(counts[k1]) / total
  38. for k2 in counts:
  39. if k1 == k2:
  40. continue
  41. p2 = float(counts[k2]) / total
  42. imp += p1 * p2
  43. return imp
  44.  
  45.  
  46. def entropy(rows):
  47. """Entropy is the sum of p(x)log(p(x)) across all
  48. the different possible results
  49.  
  50. :param rows: dataset
  51. :type rows: list
  52. :return: entropy value
  53. :rtype: float
  54. """
  55. log2 = lambda x: log(x) / log(2)
  56. results = unique_counts(rows)
  57. # Now calculate the entropy
  58. ent = 0.0
  59. for r in results.keys():
  60. p = float(results[r]) / len(rows)
  61. ent = ent - p * log2(p)
  62. return ent
  63.  
  64.  
  65. class DecisionNode:
  66. def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
  67. """
  68. :param col: index of the column (attribute) of the training set that
  69. is represented with this instance i.e. this node
  70. :type col: int
  71. :param value: the value of the node according to which the partition
  72. in the tree is made
  73. :param results: results for the current branch, value (not None)
  74. only in leaves where decision is made.
  75. :type results: dict
  76. :param tb: branch that divides from the current node when value is
  77. equal to value
  78. :type tb: DecisionNode
  79. :param fb: branch that divides from the current node when value is
  80. different from value
  81. :type fb: DecisionNode
  82. """
  83. self.col = col
  84. self.value = value
  85. self.results = results
  86. self.tb = tb
  87. self.fb = fb
  88.  
  89.  
  90. def compare_numerical(row, column, value):
  91. """Compare row value of the desired column with particular
  92. numerical value
  93.  
  94. :param row: particular row in the set
  95. :type row: list
  96. :param column: index of the column (attribute) of the train set
  97. :type column: int
  98. :param value: the value of the node according to which the
  99. partition in the tree is made
  100. :type value: int or float
  101. :return: True if the row >= value, else False
  102. :rtype: bool
  103. """
  104. return row[column] >= value
  105.  
  106.  
  107. def compare_nominal(row, column, value):
  108. """Compare row value of the desired column with particular
  109. nominal value
  110.  
  111. :param row: particular row in the set
  112. :type row: list
  113. :param column: index of the column (attribute) of the train set
  114. :type column: int
  115. :param value: the value of the node according to which the
  116. partition in the tree is made
  117. :type value: str
  118. :return: True if the row == value, else False
  119. :rtype: bool
  120. """
  121. return row[column] == value
  122.  
  123.  
  124. def divide_set(rows, column, value):
  125. """Divides a set on a specific column. Can handle numeric
  126. or nominal values.
  127.  
  128. :param rows: the train set
  129. :type rows: list(list)
  130. :param column: index of the column (attribute) of the train set
  131. :type column: int
  132. :param value: the value of the node according to which the
  133. partition in the tree for particular branch is made
  134. :type value: int or float or string
  135. :return: divided subsets
  136. :rtype: list, list
  137. """
  138. # Make a function that tells us if a row is in
  139. # the first group (true) or the second group (false)
  140. if isinstance(value, int) or isinstance(value, float):
  141. # if the value for comparison is of type int or float
  142. split_function = compare_numerical
  143. else:
  144. # if the value for comparison is of other type (string)
  145. split_function = compare_nominal
  146.  
  147. # Divide the rows into two sets and return them
  148. # for each row that split_function returns True
  149. set1 = [row for row in rows if
  150. split_function(row, column, value)]
  151. # for each row that split_function returns False
  152. set2 = [row for row in rows if
  153. not split_function(row, column, value)]
  154. return set1, set2
  155.  
  156.  
  157. def build_tree(rows, scoref=entropy):
  158. if len(rows) == 0:
  159. return DecisionNode()
  160. current_score = scoref(rows)
  161.  
  162. # Set up some variables to track the best criteria
  163. best_gain = 0.0
  164. best_criteria = None
  165. best_sets = None
  166.  
  167. column_count = len(rows[0]) - 1
  168. for col in range(0, column_count):
  169. # Generate the list of different values in this column
  170. column_values = {}
  171. for row in rows:
  172. column_values[row[col]] = 1
  173. # Now try dividing the rows up for each value in this column
  174. for value in column_values.keys():
  175. (set1, set2) = divide_set(rows, col, value)
  176.  
  177. # Information gain
  178. p = float(len(set1)) / len(rows)
  179. gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  180. if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  181. best_gain = gain
  182. best_criteria = (col, value)
  183. best_sets = (set1, set2)
  184.  
  185. # Create the subbranches
  186. if best_gain > 0:
  187. true_branch = build_tree(best_sets[0], scoref)
  188. false_branch = build_tree(best_sets[1], scoref)
  189. return DecisionNode(col=best_criteria[0], value=best_criteria[1],
  190. tb=true_branch, fb=false_branch)
  191. else:
  192. return DecisionNode(results=unique_counts(rows))
  193.  
  194.  
  195. def print_tree(tree, indent=''):
  196. # Is this a leaf node?
  197. if tree.results:
  198. print(str(tree.results))
  199. else:
  200. # Print the criteria
  201. print(str(tree.col) + ':' + str(tree.value) + '? ')
  202. # Print the branches
  203. print(indent + 'T->', end='')
  204. print_tree(tree.tb, indent + ' ')
  205. print(indent + 'F->', end='')
  206. print_tree(tree.fb, indent + ' ')
  207.  
  208.  
  209. def classify(observation, tree):
  210. if tree.results:
  211. return tree.results
  212. else:
  213. value = observation[tree.col]
  214. if isinstance(value, int) or isinstance(value, float):
  215. compare = compare_numerical
  216. else:
  217. compare = compare_nominal
  218.  
  219. if compare(observation, tree.col, tree.value):
  220. branch = tree.tb
  221. else:
  222. branch = tree.fb
  223.  
  224. return classify(observation, branch)
  225.  
  226.  
# Dataset: 159 samples, each with 6 numeric attributes and an integer
# class label (1-7) in the last column.
# NOTE(review): the columns look like the public "Fish Market" dataset
# (weight in grams plus five body measurements, species as label) -- confirm.
data = [[242.0, 23.2, 25.4, 30.0, 38.4, 13.4, 1],
        [290.0, 24.0, 26.3, 31.2, 40.0, 13.8, 1],
        [340.0, 23.9, 26.5, 31.1, 39.8, 15.1, 1],
        [363.0, 26.3, 29.0, 33.5, 38.0, 13.3, 1],
        [430.0, 26.5, 29.0, 34.0, 36.6, 15.1, 1],
        [450.0, 26.8, 29.7, 34.7, 39.2, 14.2, 1],
        [500.0, 26.8, 29.7, 34.5, 41.1, 15.3, 1],
        [390.0, 27.6, 30.0, 35.0, 36.2, 13.4, 1],
        [450.0, 27.6, 30.0, 35.1, 39.9, 13.8, 1],
        [500.0, 28.5, 30.7, 36.2, 39.3, 13.7, 1],
        [475.0, 28.4, 31.0, 36.2, 39.4, 14.1, 1],
        [500.0, 28.7, 31.0, 36.2, 39.7, 13.3, 1],
        [500.0, 29.1, 31.5, 36.4, 37.8, 12.0, 1],
        [500.0, 29.5, 32.0, 37.3, 37.3, 13.6, 1],
        [600.0, 29.4, 32.0, 37.2, 40.2, 13.9, 1],
        [600.0, 29.4, 32.0, 37.2, 41.5, 15.0, 1],
        [700.0, 30.4, 33.0, 38.3, 38.8, 13.8, 1],
        [700.0, 30.4, 33.0, 38.5, 38.8, 13.5, 1],
        [610.0, 30.9, 33.5, 38.6, 40.5, 13.3, 1],
        [650.0, 31.0, 33.5, 38.7, 37.4, 14.8, 1],
        [575.0, 31.3, 34.0, 39.5, 38.3, 14.1, 1],
        [685.0, 31.4, 34.0, 39.2, 40.8, 13.7, 1],
        [620.0, 31.5, 34.5, 39.7, 39.1, 13.3, 1],
        [680.0, 31.8, 35.0, 40.6, 38.1, 15.1, 1],
        [700.0, 31.9, 35.0, 40.5, 40.1, 13.8, 1],
        [725.0, 31.8, 35.0, 40.9, 40.0, 14.8, 1],
        [720.0, 32.0, 35.0, 40.6, 40.3, 15.0, 1],
        [714.0, 32.7, 36.0, 41.5, 39.8, 14.1, 1],
        [850.0, 32.8, 36.0, 41.6, 40.6, 14.9, 1],
        [1000.0, 33.5, 37.0, 42.6, 44.5, 15.5, 1],
        [920.0, 35.0, 38.5, 44.1, 40.9, 14.3, 1],
        [955.0, 35.0, 38.5, 44.0, 41.1, 14.3, 1],
        [925.0, 36.2, 39.5, 45.3, 41.4, 14.9, 1],
        [975.0, 37.4, 41.0, 45.9, 40.6, 14.7, 1],
        [950.0, 38.0, 41.0, 46.5, 37.9, 13.7, 1],
        [270.0, 23.6, 26.0, 28.7, 29.2, 14.8, 2],
        [270.0, 24.1, 26.5, 29.3, 27.8, 14.5, 2],
        [306.0, 25.6, 28.0, 30.8, 28.5, 15.2, 2],
        [540.0, 28.5, 31.0, 34.0, 31.6, 19.3, 2],
        [800.0, 33.7, 36.4, 39.6, 29.7, 16.6, 2],
        [1000.0, 37.3, 40.0, 43.5, 28.4, 15.0, 2],
        [40.0, 12.9, 14.1, 16.2, 25.6, 14.0, 3],
        [69.0, 16.5, 18.2, 20.3, 26.1, 13.9, 3],
        [78.0, 17.5, 18.8, 21.2, 26.3, 13.7, 3],
        [87.0, 18.2, 19.8, 22.2, 25.3, 14.3, 3],
        [120.0, 18.6, 20.0, 22.2, 28.0, 16.1, 3],
        [0.0, 19.0, 20.5, 22.8, 28.4, 14.7, 3],
        [110.0, 19.1, 20.8, 23.1, 26.7, 14.7, 3],
        [120.0, 19.4, 21.0, 23.7, 25.8, 13.9, 3],
        [150.0, 20.4, 22.0, 24.7, 23.5, 15.2, 3],
        [145.0, 20.5, 22.0, 24.3, 27.3, 14.6, 3],
        [160.0, 20.5, 22.5, 25.3, 27.8, 15.1, 3],
        [140.0, 21.0, 22.5, 25.0, 26.2, 13.3, 3],
        [160.0, 21.1, 22.5, 25.0, 25.6, 15.2, 3],
        [169.0, 22.0, 24.0, 27.2, 27.7, 14.1, 3],
        [161.0, 22.0, 23.4, 26.7, 25.9, 13.6, 3],
        [200.0, 22.1, 23.5, 26.8, 27.6, 15.4, 3],
        [180.0, 23.6, 25.2, 27.9, 25.4, 14.0, 3],
        [290.0, 24.0, 26.0, 29.2, 30.4, 15.4, 3],
        [272.0, 25.0, 27.0, 30.6, 28.0, 15.6, 3],
        [390.0, 29.5, 31.7, 35.0, 27.1, 15.3, 3],
        [55.0, 13.5, 14.7, 16.5, 41.5, 14.1, 4],
        [60.0, 14.3, 15.5, 17.4, 37.8, 13.3, 4],
        [90.0, 16.3, 17.7, 19.8, 37.4, 13.5, 4],
        [120.0, 17.5, 19.0, 21.3, 39.4, 13.7, 4],
        [150.0, 18.4, 20.0, 22.4, 39.7, 14.7, 4],
        [140.0, 19.0, 20.7, 23.2, 36.8, 14.2, 4],
        [170.0, 19.0, 20.7, 23.2, 40.5, 14.7, 4],
        [145.0, 19.8, 21.5, 24.1, 40.4, 13.1, 4],
        [200.0, 21.2, 23.0, 25.8, 40.1, 14.2, 4],
        [273.0, 23.0, 25.0, 28.0, 39.6, 14.8, 4],
        [300.0, 24.0, 26.0, 29.0, 39.2, 14.6, 4],
        [6.7, 9.3, 9.8, 10.8, 16.1, 9.7, 5],
        [7.5, 10.0, 10.5, 11.6, 17.0, 10.0, 5],
        [7.0, 10.1, 10.6, 11.6, 14.9, 9.9, 5],
        [9.7, 10.4, 11.0, 12.0, 18.3, 11.5, 5],
        [9.8, 10.7, 11.2, 12.4, 16.8, 10.3, 5],
        [8.7, 10.8, 11.3, 12.6, 15.7, 10.2, 5],
        [10.0, 11.3, 11.8, 13.1, 16.9, 9.8, 5],
        [9.9, 11.3, 11.8, 13.1, 16.9, 8.9, 5],
        [9.8, 11.4, 12.0, 13.2, 16.7, 8.7, 5],
        [12.2, 11.5, 12.2, 13.4, 15.6, 10.4, 5],
        [13.4, 11.7, 12.4, 13.5, 18.0, 9.4, 5],
        [12.2, 12.1, 13.0, 13.8, 16.5, 9.1, 5],
        [19.7, 13.2, 14.3, 15.2, 18.9, 13.6, 5],
        [19.9, 13.8, 15.0, 16.2, 18.1, 11.6, 5],
        [200.0, 30.0, 32.3, 34.8, 16.0, 9.7, 6],
        [300.0, 31.7, 34.0, 37.8, 15.1, 11.0, 6],
        [300.0, 32.7, 35.0, 38.8, 15.3, 11.3, 6],
        [300.0, 34.8, 37.3, 39.8, 15.8, 10.1, 6],
        [430.0, 35.5, 38.0, 40.5, 18.0, 11.3, 6],
        [345.0, 36.0, 38.5, 41.0, 15.6, 9.7, 6],
        [456.0, 40.0, 42.5, 45.5, 16.0, 9.5, 6],
        [510.0, 40.0, 42.5, 45.5, 15.0, 9.8, 6],
        [540.0, 40.1, 43.0, 45.8, 17.0, 11.2, 6],
        [500.0, 42.0, 45.0, 48.0, 14.5, 10.2, 6],
        [567.0, 43.2, 46.0, 48.7, 16.0, 10.0, 6],
        [770.0, 44.8, 48.0, 51.2, 15.0, 10.5, 6],
        [950.0, 48.3, 51.7, 55.1, 16.2, 11.2, 6],
        [1250.0, 52.0, 56.0, 59.7, 17.9, 11.7, 6],
        [1600.0, 56.0, 60.0, 64.0, 15.0, 9.6, 6],
        [1550.0, 56.0, 60.0, 64.0, 15.0, 9.6, 6],
        [1650.0, 59.0, 63.4, 68.0, 15.9, 11.0, 6],
        [5.9, 7.5, 8.4, 8.8, 24.0, 16.0, 7],
        [32.0, 12.5, 13.7, 14.7, 24.0, 13.6, 7],
        [40.0, 13.8, 15.0, 16.0, 23.9, 15.2, 7],
        [51.5, 15.0, 16.2, 17.2, 26.7, 15.3, 7],
        [70.0, 15.7, 17.4, 18.5, 24.8, 15.9, 7],
        [100.0, 16.2, 18.0, 19.2, 27.2, 17.3, 7],
        [78.0, 16.8, 18.7, 19.4, 26.8, 16.1, 7],
        [80.0, 17.2, 19.0, 20.2, 27.9, 15.1, 7],
        [85.0, 17.8, 19.6, 20.8, 24.7, 14.6, 7],
        [85.0, 18.2, 20.0, 21.0, 24.2, 13.2, 7],
        [110.0, 19.0, 21.0, 22.5, 25.3, 15.8, 7],
        [115.0, 19.0, 21.0, 22.5, 26.3, 14.7, 7],
        [125.0, 19.0, 21.0, 22.5, 25.3, 16.3, 7],
        [130.0, 19.3, 21.3, 22.8, 28.0, 15.5, 7],
        [120.0, 20.0, 22.0, 23.5, 26.0, 14.5, 7],
        [120.0, 20.0, 22.0, 23.5, 24.0, 15.0, 7],
        [130.0, 20.0, 22.0, 23.5, 26.0, 15.0, 7],
        [135.0, 20.0, 22.0, 23.5, 25.0, 15.0, 7],
        [110.0, 20.0, 22.0, 23.5, 23.5, 17.0, 7],
        [130.0, 20.5, 22.5, 24.0, 24.4, 15.1, 7],
        [150.0, 20.5, 22.5, 24.0, 28.3, 15.1, 7],
        [145.0, 20.7, 22.7, 24.2, 24.6, 15.0, 7],
        [150.0, 21.0, 23.0, 24.5, 21.3, 14.8, 7],
        [170.0, 21.5, 23.5, 25.0, 25.1, 14.9, 7],
        [225.0, 22.0, 24.0, 25.5, 28.6, 14.6, 7],
        [145.0, 22.0, 24.0, 25.5, 25.0, 15.0, 7],
        [188.0, 22.6, 24.6, 26.2, 25.7, 15.9, 7],
        [180.0, 23.0, 25.0, 26.5, 24.3, 13.9, 7],
        [197.0, 23.5, 25.6, 27.0, 24.3, 15.7, 7],
        [218.0, 25.0, 26.5, 28.0, 25.6, 14.8, 7],
        [300.0, 25.2, 27.3, 28.7, 29.0, 17.9, 7],
        [260.0, 25.4, 27.5, 28.9, 24.8, 15.0, 7],
        [265.0, 25.4, 27.5, 28.9, 24.4, 15.0, 7],
        [250.0, 25.4, 27.5, 28.9, 25.2, 15.8, 7],
        [250.0, 25.9, 28.0, 29.4, 26.6, 14.3, 7],
        [300.0, 26.9, 28.7, 30.1, 25.2, 15.4, 7],
        [320.0, 27.8, 30.0, 31.6, 24.1, 15.1, 7],
        [514.0, 30.5, 32.8, 34.0, 29.5, 17.7, 7],
        [556.0, 32.0, 34.5, 36.5, 28.1, 17.5, 7],
        [840.0, 32.5, 35.0, 37.3, 30.8, 20.9, 7],
        [685.0, 34.0, 36.5, 39.0, 27.9, 17.6, 7],
        [700.0, 34.0, 36.0, 38.3, 27.7, 17.6, 7],
        [700.0, 34.5, 37.0, 39.4, 27.5, 15.9, 7],
        [690.0, 34.6, 37.0, 39.3, 26.9, 16.2, 7],
        [900.0, 36.5, 39.0, 41.4, 26.9, 18.1, 7],
        [650.0, 36.5, 39.0, 41.4, 26.9, 14.5, 7],
        [820.0, 36.6, 39.0, 41.3, 30.1, 17.8, 7],
        [850.0, 36.9, 40.0, 42.3, 28.2, 16.8, 7],
        [900.0, 37.0, 40.0, 42.5, 27.6, 17.0, 7],
        [1015.0, 37.0, 40.0, 42.4, 29.2, 17.6, 7],
        [820.0, 37.1, 40.0, 42.5, 26.2, 15.6, 7],
        [1100.0, 39.0, 42.0, 44.6, 28.7, 15.4, 7],
        [1000.0, 39.8, 43.0, 45.2, 26.4, 16.1, 7],
        [1100.0, 40.1, 43.0, 45.5, 27.5, 16.3, 7],
        [1000.0, 40.2, 43.5, 46.0, 27.4, 17.7, 7],
        [1000.0, 41.1, 44.0, 46.6, 26.8, 16.3, 7]]
  386.  
  387. if __name__ == "__main__":
  388. index = int(input())
  389.  
  390. new = []
  391.  
  392. sz = 8 * [1]
  393.  
  394. for i in data:
  395. if sz[i[6]] <5:
  396. new.append(i)
  397. sz[i[6]] = sz[i[6]] + 1
  398.  
  399. tree = build_tree(new)
  400.  
  401.  
  402.  
  403. for k in classify(data[index], tree).keys():
  404. print(k)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement