Advertisement
Guest User

asd

a guest
Jun 16th, 2019
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 13.55 KB | None | 0 0
  1. from math import log
  2.  
  3.  
  4. def unique_counts(rows):
  5. """Create counts of possible results (the last column of
  6. each row is the result)
  7.  
  8. :param rows: dataset
  9. :type rows: list
  10. :return: dictionary of possible classes as keys and count
  11. as values
  12. :rtype: dict
  13. """
  14. results = {}
  15. for row in rows:
  16. # The result is the last column
  17. r = row[len(row) - 1]
  18. if r not in results:
  19. results[r] = 0
  20. results[r] += 1
  21. return results
  22.  
  23.  
  24. def gini_impurity(rows):
  25. """Probability that a randomly placed item will
  26. be in the wrong category
  27.  
  28. :param rows: dataset
  29. :type rows: list
  30. :return: Gini impurity
  31. :rtype: float
  32. """
  33. total = len(rows)
  34. counts = unique_counts(rows)
  35. imp = 0
  36. for k1 in counts:
  37. p1 = float(counts[k1]) / total
  38. for k2 in counts:
  39. if k1 == k2:
  40. continue
  41. p2 = float(counts[k2]) / total
  42. imp += p1 * p2
  43. return imp
  44.  
  45.  
  46. def entropy(rows):
  47. """Entropy is the sum of p(x)log(p(x)) across all
  48. the different possible results
  49.  
  50. :param rows: dataset
  51. :type rows: list
  52. :return: entropy value
  53. :rtype: float
  54. """
  55. log2 = lambda x: log(x) / log(2)
  56. results = unique_counts(rows)
  57. # Now calculate the entropy
  58. ent = 0.0
  59. for r in results.keys():
  60. p = float(results[r]) / len(rows)
  61. ent = ent - p * log2(p)
  62. return ent
  63.  
  64.  
  65. class DecisionNode:
  66. def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
  67. """
  68. :param col: index of the column (attribute) of the training set that
  69. is represented with this instance i.e. this node
  70. :type col: int
  71. :param value: the value of the node according to which the partition
  72. in the tree is made
  73. :param results: results for the current branch, value (not None)
  74. only in leaves where decision is made.
  75. :type results: dict
  76. :param tb: branch that divides from the current node when value is
  77. equal to value
  78. :type tb: DecisionNode
  79. :param fb: branch that divides from the current node when value is
  80. different from value
  81. :type fb: DecisionNode
  82. """
  83. self.col = col
  84. self.value = value
  85. self.results = results
  86. self.tb = tb
  87. self.fb = fb
  88.  
  89.  
  90. def compare_numerical(row, column, value):
  91. """Compare row value of the desired column with particular
  92. numerical value
  93.  
  94. :param row: particular row in the set
  95. :type row: list
  96. :param column: index of the column (attribute) of the train set
  97. :type column: int
  98. :param value: the value of the node according to which the
  99. partition in the tree is made
  100. :type value: int or float
  101. :return: True if the row >= value, else False
  102. :rtype: bool
  103. """
  104. return row[column] >= value
  105.  
  106.  
  107. def compare_nominal(row, column, value):
  108. """Compare row value of the desired column with particular
  109. nominal value
  110.  
  111. :param row: particular row in the set
  112. :type row: list
  113. :param column: index of the column (attribute) of the train set
  114. :type column: int
  115. :param value: the value of the node according to which the
  116. partition in the tree is made
  117. :type value: str
  118. :return: True if the row == value, else False
  119. :rtype: bool
  120. """
  121. return row[column] == value
  122.  
  123.  
  124. def divide_set(rows, column, value):
  125. """Divides a set on a specific column. Can handle numeric
  126. or nominal values.
  127.  
  128. :param rows: the train set
  129. :type rows: list(list)
  130. :param column: index of the column (attribute) of the train set
  131. :type column: int
  132. :param value: the value of the node according to which the
  133. partition in the tree for particular branch is made
  134. :type value: int or float or string
  135. :return: divided subsets
  136. :rtype: list, list
  137. """
  138. # Make a function that tells us if a row is in
  139. # the first group (true) or the second group (false)
  140. if isinstance(value, int) or isinstance(value, float):
  141. # if the value for comparison is of type int or float
  142. split_function = compare_numerical
  143. else:
  144. # if the value for comparison is of other type (string)
  145. split_function = compare_nominal
  146.  
  147. # Divide the rows into two sets and return them
  148. # for each row that split_function returns True
  149. set1 = [row for row in rows if
  150. split_function(row, column, value)]
  151. # for each row that split_function returns False
  152. set2 = [row for row in rows if
  153. not split_function(row, column, value)]
  154. return set1, set2
  155.  
  156.  
  157. def build_tree(rows, scoref=entropy):
  158. if len(rows) == 0:
  159. return DecisionNode()
  160. current_score = scoref(rows)
  161.  
  162. # Set up some variables to track the best criteria
  163. best_gain = 0.0
  164. best_criteria = None
  165. best_sets = None
  166.  
  167. column_count = len(rows[0]) - 1
  168. for col in range(0, column_count):
  169. # Generate the list of different values in this column
  170. column_values = {}
  171. for row in rows:
  172. column_values[row[col]] = 1
  173. # Now try dividing the rows up for each value in this column
  174. for value in column_values.keys():
  175. (set1, set2) = divide_set(rows, col, value)
  176.  
  177. # Information gain
  178. p = float(len(set1)) / len(rows)
  179. gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  180. if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  181. best_gain = gain
  182. best_criteria = (col, value)
  183. best_sets = (set1, set2)
  184.  
  185. # Create the subbranches
  186. if best_gain > 0:
  187. true_branch = build_tree(best_sets[0], scoref)
  188. false_branch = build_tree(best_sets[1], scoref)
  189. return DecisionNode(col=best_criteria[0], value=best_criteria[1],
  190. tb=true_branch, fb=false_branch)
  191. else:
  192. return DecisionNode(results=unique_counts(rows))
  193.  
  194.  
  195. def print_tree(tree, indent='', level=0):
  196. # Is this a leaf node?
  197. if tree.results:
  198. print(str(tree.results))
  199. else:
  200. # Print the criteria
  201. print(str(tree.col) + ':' + str(tree.value) + '? ' + 'Level= (' + str(level) + ')')
  202. # Print the branches
  203. print(indent + 'T-> ', end='')
  204. print_tree(tree.tb, indent + ' ', level + 1)
  205. print(indent + 'F-> ', end='')
  206. print_tree(tree.fb, indent + ' ', level + 1)
  207.  
  208.  
  209. def classify(observation, tree):
  210. if tree.results:
  211. return tree.results
  212. else:
  213. value = observation[tree.col]
  214. if isinstance(value, int) or isinstance(value, float):
  215. compare = compare_numerical
  216. else:
  217. compare = compare_nominal
  218.  
  219. if compare(observation, tree.col, tree.value):
  220. branch = tree.tb
  221. else:
  222. branch = tree.fb
  223.  
  224. return classify(observation, branch)
  225.  
  226.  
# Iris measurement samples: [sepal length, sepal width, petal length,
# petal width, species]. The species label in the last column is the
# classification target used by build_tree/classify.
trainingData = [
    [6.3, 2.9, 5.6, 1.8, 'I. virginica'],
    [6.5, 3.0, 5.8, 2.2, 'I. virginica'],
    [7.6, 3.0, 6.6, 2.1, 'I. virginica'],
    [4.9, 2.5, 4.5, 1.7, 'I. virginica'],
    [7.3, 2.9, 6.3, 1.8, 'I. virginica'],
    [6.7, 2.5, 5.8, 1.8, 'I. virginica'],
    [7.2, 3.6, 6.1, 2.5, 'I. virginica'],
    [6.5, 3.2, 5.1, 2.0, 'I. virginica'],
    [6.4, 2.7, 5.3, 1.9, 'I. virginica'],
    [6.8, 3.0, 5.5, 2.1, 'I. virginica'],
    [5.7, 2.5, 5.0, 2.0, 'I. virginica'],
    [5.8, 2.8, 5.1, 2.4, 'I. virginica'],
    [6.4, 3.2, 5.3, 2.3, 'I. virginica'],
    [6.5, 3.0, 5.5, 1.8, 'I. virginica'],
    [7.7, 3.8, 6.7, 2.2, 'I. virginica'],
    [7.7, 2.6, 6.9, 2.3, 'I. virginica'],
    [6.0, 2.2, 5.0, 1.5, 'I. virginica'],
    [6.9, 3.2, 5.7, 2.3, 'I. virginica'],
    [5.6, 2.8, 4.9, 2.0, 'I. virginica'],
    [7.7, 2.8, 6.7, 2.0, 'I. virginica'],
    [6.3, 2.7, 4.9, 1.8, 'I. virginica'],
    [6.7, 3.3, 5.7, 2.1, 'I. virginica'],
    [7.2, 3.2, 6.0, 1.8, 'I. virginica'],
    [6.2, 2.8, 4.8, 1.8, 'I. virginica'],
    [6.1, 3.0, 4.9, 1.8, 'I. virginica'],
    [6.4, 2.8, 5.6, 2.1, 'I. virginica'],
    [7.2, 3.0, 5.8, 1.6, 'I. virginica'],
    [7.4, 2.8, 6.1, 1.9, 'I. virginica'],
    [7.9, 3.8, 6.4, 2.0, 'I. virginica'],
    [6.4, 2.8, 5.6, 2.2, 'I. virginica'],
    [6.3, 2.8, 5.1, 1.5, 'I. virginica'],
    [6.1, 2.6, 5.6, 1.4, 'I. virginica'],
    [7.7, 3.0, 6.1, 2.3, 'I. virginica'],
    [6.3, 3.4, 5.6, 2.4, 'I. virginica'],
    [5.1, 3.5, 1.4, 0.2, 'I. setosa'],
    [4.9, 3.0, 1.4, 0.2, 'I. setosa'],
    [4.7, 3.2, 1.3, 0.2, 'I. setosa'],
    [4.6, 3.1, 1.5, 0.2, 'I. setosa'],
    [5.0, 3.6, 1.4, 0.2, 'I. setosa'],
    [5.4, 3.9, 1.7, 0.4, 'I. setosa'],
    [4.6, 3.4, 1.4, 0.3, 'I. setosa'],
    [5.0, 3.4, 1.5, 0.2, 'I. setosa'],
    [4.4, 2.9, 1.4, 0.2, 'I. setosa'],
    [4.9, 3.1, 1.5, 0.1, 'I. setosa'],
    [5.4, 3.7, 1.5, 0.2, 'I. setosa'],
    [4.8, 3.4, 1.6, 0.2, 'I. setosa'],
    [4.8, 3.0, 1.4, 0.1, 'I. setosa'],
    [4.3, 3.0, 1.1, 0.1, 'I. setosa'],
    [5.8, 4.0, 1.2, 0.2, 'I. setosa'],
    [5.7, 4.4, 1.5, 0.4, 'I. setosa'],
    [5.4, 3.9, 1.3, 0.4, 'I. setosa'],
    [5.1, 3.5, 1.4, 0.3, 'I. setosa'],
    [5.7, 3.8, 1.7, 0.3, 'I. setosa'],
    [5.1, 3.8, 1.5, 0.3, 'I. setosa'],
    [5.4, 3.4, 1.7, 0.2, 'I. setosa'],
    [5.1, 3.7, 1.5, 0.4, 'I. setosa'],
    [4.6, 3.6, 1.0, 0.2, 'I. setosa'],
    [5.1, 3.3, 1.7, 0.5, 'I. setosa'],
    [4.8, 3.4, 1.9, 0.2, 'I. setosa'],
    [5.0, 3.0, 1.6, 0.2, 'I. setosa'],
    [5.0, 3.4, 1.6, 0.4, 'I. setosa'],
    [5.2, 3.5, 1.5, 0.2, 'I. setosa'],
    [5.2, 3.4, 1.4, 0.2, 'I. setosa'],
    [5.5, 2.3, 4.0, 1.3, 'I. versicolor'],
    [6.5, 2.8, 4.6, 1.5, 'I. versicolor'],
    [5.7, 2.8, 4.5, 1.3, 'I. versicolor'],
    [6.3, 3.3, 4.7, 1.6, 'I. versicolor'],
    [4.9, 2.4, 3.3, 1.0, 'I. versicolor'],
    [6.6, 2.9, 4.6, 1.3, 'I. versicolor'],
    [5.2, 2.7, 3.9, 1.4, 'I. versicolor'],
    [5.0, 2.0, 3.5, 1.0, 'I. versicolor'],
    [5.9, 3.0, 4.2, 1.5, 'I. versicolor'],
    [6.0, 2.2, 4.0, 1.0, 'I. versicolor'],
    [6.1, 2.9, 4.7, 1.4, 'I. versicolor'],
    [5.6, 2.9, 3.6, 1.3, 'I. versicolor'],
    [6.7, 3.1, 4.4, 1.4, 'I. versicolor'],
    [5.6, 3.0, 4.5, 1.5, 'I. versicolor'],
    [5.8, 2.7, 4.1, 1.0, 'I. versicolor'],
    [6.2, 2.2, 4.5, 1.5, 'I. versicolor'],
    [5.6, 2.5, 3.9, 1.1, 'I. versicolor'],
    [5.9, 3.2, 4.8, 1.8, 'I. versicolor'],
    [6.1, 2.8, 4.0, 1.3, 'I. versicolor'],
    [6.3, 2.5, 4.9, 1.5, 'I. versicolor'],
    [6.1, 2.8, 4.7, 1.2, 'I. versicolor'],
    [6.4, 2.9, 4.3, 1.3, 'I. versicolor'],
    [6.6, 3.0, 4.4, 1.4, 'I. versicolor'],
    [6.8, 2.8, 4.8, 1.4, 'I. versicolor'],
    [6.7, 3.0, 5.0, 1.7, 'I. versicolor'],
    [6.0, 2.9, 4.5, 1.5, 'I. versicolor'],
    [5.7, 2.6, 3.5, 1.0, 'I. versicolor'],
    [5.5, 2.4, 3.8, 1.1, 'I. versicolor'],
    [5.5, 2.4, 3.7, 1.0, 'I. versicolor'],
    [5.8, 2.7, 3.9, 1.2, 'I. versicolor'],
    [6.0, 2.7, 5.1, 1.6, 'I. versicolor'],
    [5.4, 3.0, 4.5, 1.5, 'I. versicolor'],
    [6.0, 3.4, 4.5, 1.6, 'I. versicolor'],
    [6.7, 3.1, 4.7, 1.5, 'I. versicolor'],
    [6.3, 2.3, 4.4, 1.3, 'I. versicolor'],
    [5.6, 3.0, 4.1, 1.3, 'I. versicolor'],
    [5.5, 2.5, 4.0, 1.3, 'I. versicolor'],
    [5.5, 2.6, 4.4, 1.2, 'I. versicolor'],
    [6.1, 3.0, 4.6, 1.4, 'I. versicolor'],
    [5.8, 2.6, 4.0, 1.2, 'I. versicolor'],
    [5.0, 2.3, 3.3, 1.0, 'I. versicolor'],
    [5.6, 2.7, 4.2, 1.3, 'I. versicolor'],
    [5.7, 3.0, 4.2, 1.2, 'I. versicolor'],
    [5.7, 2.9, 4.2, 1.3, 'I. versicolor'],
    [6.2, 2.9, 4.3, 1.3, 'I. versicolor'],
    [5.1, 2.5, 3.0, 1.1, 'I. versicolor'],
    [5.7, 2.8, 4.1, 1.3, 'I. versicolor'],
    [6.4, 3.1, 5.5, 1.8, 'I. virginica'],
    [6.0, 3.0, 4.8, 1.8, 'I. virginica'],
    [6.9, 3.1, 5.4, 2.1, 'I. virginica'],
    [6.7, 3.1, 5.6, 2.4, 'I. virginica'],
    [6.9, 3.1, 5.1, 2.3, 'I. virginica'],
    [5.8, 2.7, 5.1, 1.9, 'I. virginica'],
    [6.8, 3.2, 5.9, 2.3, 'I. virginica'],
    [6.7, 3.3, 5.7, 2.5, 'I. virginica'],
    [6.7, 3.0, 5.2, 2.3, 'I. virginica'],
    [6.3, 2.5, 5.0, 1.9, 'I. virginica'],
    [6.5, 3.0, 5.2, 2.0, 'I. virginica'],
    [6.2, 3.4, 5.4, 2.3, 'I. virginica'],
    [4.7, 3.2, 1.6, 0.2, 'I. setosa'],
    [4.8, 3.1, 1.6, 0.2, 'I. setosa'],
    [5.4, 3.4, 1.5, 0.4, 'I. setosa'],
    [5.2, 4.1, 1.5, 0.1, 'I. setosa'],
    [5.5, 4.2, 1.4, 0.2, 'I. setosa'],
    [4.9, 3.1, 1.5, 0.2, 'I. setosa'],
    [5.0, 3.2, 1.2, 0.2, 'I. setosa'],
    [5.5, 3.5, 1.3, 0.2, 'I. setosa'],
    [4.9, 3.6, 1.4, 0.1, 'I. setosa'],
    [4.4, 3.0, 1.3, 0.2, 'I. setosa'],
    [5.1, 3.4, 1.5, 0.2, 'I. setosa'],
    [5.0, 3.5, 1.3, 0.3, 'I. setosa'],
    [4.5, 2.3, 1.3, 0.3, 'I. setosa'],
    [4.4, 3.2, 1.3, 0.2, 'I. setosa'],
    [5.0, 3.5, 1.6, 0.6, 'I. setosa'],
    [5.1, 3.8, 1.9, 0.4, 'I. setosa'],
    [4.8, 3.0, 1.4, 0.3, 'I. setosa'],
    [5.1, 3.8, 1.6, 0.2, 'I. setosa'],
    [5.9, 3.0, 5.1, 1.8, 'I. virginica']
]
  370.  
  371. if __name__ == "__main__":
  372. att1 = float(input())
  373. att2 = float(input())
  374. att3 = float(input())
  375. att4 = float(input())
  376. planttype = input()
  377. testCase = [att1, att2, att3, att4, planttype]
  378.  
  379. half = int(len(trainingData)/2)
  380. data1 = trainingData[:half]
  381. data2 = trainingData[half:]
  382.  
  383. tree1 = build_tree(data1)
  384. tree2 = build_tree(data2)
  385.  
  386. for key in classify(testCase, tree1).keys():
  387. res1 = key
  388. for key in classify(testCase, tree2).keys():
  389. res2 = key
  390.  
  391. print_tree(tree1)
  392. print_tree(tree2)
  393.  
  394. if res1 == res2:
  395. print(res1)
  396. else:
  397. print("KONTRADIKCIJA")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement