Advertisement
Guest User

Untitled

a guest
Jan 29th, 2020
118
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 14.75 KB | None | 0 0
  1. from math import log
  2.  
  3.  
  4. def unique_counts(rows):
  5. """Креирај броење на можни резултати (последната колона
  6. во секоја редица е класата)
  7.  
  8. :param rows: dataset
  9. :type rows: list
  10. :return: dictionary of possible classes as keys and count
  11. as values
  12. :rtype: dict
  13. """
  14. results = {}
  15. for row in rows:
  16. # Клацата е последната колона
  17. r = row[len(row) - 1]
  18. if r not in results:
  19. results[r] = 0
  20. results[r] += 1
  21. return results
  22.  
  23.  
  24. def gini_impurity(rows):
  25. """Probability that a randomly placed item will
  26. be in the wrong category
  27.  
  28. :param rows: dataset
  29. :type rows: list
  30. :return: Gini impurity
  31. :rtype: float
  32. """
  33. total = len(rows)
  34. counts = unique_counts(rows)
  35. imp = 0
  36. for k1 in counts:
  37. p1 = float(counts[k1]) / total
  38. for k2 in counts:
  39. if k1 == k2:
  40. continue
  41. p2 = float(counts[k2]) / total
  42. imp += p1 * p2
  43. return imp
  44.  
  45.  
  46. def entropy(rows):
  47. """Ентропијата е сума од p(x)log(p(x)) за сите
  48. можни резултати
  49.  
  50. :param rows: податочно множество
  51. :type rows: list
  52. :return: вредност за ентропијата
  53. :rtype: float
  54. """
  55. log2 = lambda x: log(x) / log(2)
  56. results = unique_counts(rows)
  57. # Пресметка на ентропијата
  58. ent = 0.0
  59. for r in results.keys():
  60. p = float(results[r]) / len(rows)
  61. ent = ent - p * log2(p)
  62. return ent
  63.  
  64.  
  65. class DecisionNode:
  66. def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
  67. """
  68. :param col: индексот на колоната (атрибутот) од тренинг множеството
  69. која се претставува со оваа инстанца т.е. со овој јазол
  70. :type col: int
  71. :param value: вредноста на јазолот според кој се дели дрвото
  72. :param results: резултати за тековната гранка, вредност (различна
  73. од None) само кај јазлите-листови во кои се донесува
  74. одлуката.
  75. :type results: dict
  76. :param tb: гранка која се дели од тековниот јазол кога вредноста е
  77. еднаква на value
  78. :type tb: DecisionNode
  79. :param fb: гранка која се дели од тековниот јазол кога вредноста е
  80. различна од value
  81. :type fb: DecisionNode
  82. """
  83. self.col = col
  84. self.value = value
  85. self.results = results
  86. self.tb = tb
  87. self.fb = fb
  88.  
  89.  
  90. def compare_numerical(row, column, value):
  91. """Споредба на вредноста од редицата на посакуваната колона со
  92. зададена нумеричка вредност
  93.  
  94. :param row: дадена редица во податочното множество
  95. :type row: list
  96. :param column: индекс на колоната (атрибутот) од тренирачкото множество
  97. :type column: int
  98. :param value: вредност на јазелот во согласност со кој се прави
  99. поделбата во дрвото
  100. :type value: int or float
  101. :return: True ако редицата >= value, инаку False
  102. :rtype: bool
  103. """
  104. return row[column] >= value
  105.  
  106.  
  107. def compare_nominal(row, column, value):
  108. """Споредба на вредноста од редицата на посакуваната колона со
  109. зададена номинална вредност
  110.  
  111. :param row: дадена редица во податочното множество
  112. :type row: list
  113. :param column: индекс на колоната (атрибутот) од тренирачкото множество
  114. :type column: int
  115. :param value: вредност на јазелот во согласност со кој се прави
  116. поделбата во дрвото
  117. :type value: str
  118. :return: True ако редицата == value, инаку False
  119. :rtype: bool
  120. """
  121. return row[column] == value
  122.  
  123.  
  124. def divide_set(rows, column, value):
  125. """Поделба на множеството според одредена колона. Може да се справи
  126. со нумерички или номинални вредности.
  127.  
  128. :param rows: тренирачко множество
  129. :type rows: list(list)
  130. :param column: индекс на колоната (атрибутот) од тренирачкото множество
  131. :type column: int
  132. :param value: вредност на јазелот во зависност со кој се прави поделбата
  133. во дрвото за конкретната гранка
  134. :type value: int or float or str
  135. :return: поделени подмножества
  136. :rtype: list, list
  137. """
  138. # Направи функција која ни кажува дали редицата е во
  139. # првата група (True) или втората група (False)
  140. if isinstance(value, int) or isinstance(value, float):
  141. # ако вредноста за споредба е од тип int или float
  142. split_function = compare_numerical
  143. else:
  144. # ако вредноста за споредба е од друг тип (string)
  145. split_function = compare_nominal
  146.  
  147. # Подели ги редиците во две подмножества и врати ги
  148. # за секој ред за кој split_function враќа True
  149. set1 = [row for row in rows if
  150. split_function(row, column, value)]
  151. # за секој ред за кој split_function враќа False
  152. set2 = [row for row in rows if
  153. not split_function(row, column, value)]
  154. return set1, set2
  155.  
  156.  
  157. def build_tree(rows, scoref=entropy):
  158.  
  159. if len(rows) == 0:
  160. return DecisionNode()
  161. current_score = scoref(rows)
  162.  
  163. # променливи со кои следиме кој критериум е најдобар
  164. best_gain = 0.0
  165. best_criteria = None
  166. best_sets = None
  167.  
  168. column_count = len(rows[0]) - 1
  169. for col in range(0, column_count):
  170. # за секоја колона (col се движи во интервалот од 0 до
  171. # column_count - 1)
  172. # Следниов циклус е за генерирање на речник од различни
  173. # вредности во оваа колона
  174. column_values = {}
  175. for row in rows:
  176. column_values[row[col]] = 1
  177. # за секоја редица се зема вредноста во оваа колона и се
  178. # поставува како клуч во column_values
  179. for value in column_values.keys():
  180. (set1, set2) = divide_set(rows, col, value)
  181.  
  182. # Информациона добивка
  183. p = float(len(set1)) / len(rows)
  184. gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  185. if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  186. best_gain = gain
  187. best_criteria = (col, value)
  188. best_sets = (set1, set2)
  189.  
  190. # Креирај ги подгранките
  191. if best_gain > 0:
  192. true_branch = build_tree(best_sets[0], scoref)
  193. false_branch = build_tree(best_sets[1], scoref)
  194. return DecisionNode(col=best_criteria[0], value=best_criteria[1],
  195. tb=true_branch, fb=false_branch)
  196. else:
  197. return DecisionNode(results=unique_counts(rows))
  198.  
  199.  
  200. def print_tree(tree, indent=''):
  201. # Дали е ова лист јазел?
  202. if tree.results:
  203. print(str(tree.results))
  204. else:
  205. # Се печати условот
  206. print(str(tree.col) + ':' + str(tree.value) + '? ')
  207. # Се печатат True гранките, па False гранките
  208. print(indent + 'T->', end='')
  209. print_tree(tree.tb, indent + ' ')
  210. print(indent + 'F->', end='')
  211. print_tree(tree.fb, indent + ' ')
  212.  
  213.  
  214. def classify(observation, tree):
  215. if tree.results:
  216. return tree.results
  217. else:
  218. value = observation[tree.col]
  219. if isinstance(value, int) or isinstance(value, float):
  220. compare = compare_numerical
  221. else:
  222. compare = compare_nominal
  223.  
  224. if compare(observation, tree.col, tree.value):
  225. branch = tree.tb
  226. else:
  227. branch = tree.fb
  228.  
  229. return classify(observation, branch)
  230.  
  231.  
  232. dataset = [[6.3, 2.3, 4.4, 1.3, 2],
  233. [6.4, 2.8, 5.6, 2.1, 0],
  234. [5.1, 3.3, 1.7, 0.5, 1],
  235. [5.1, 3.5, 1.4, 0.2, 1],
  236. [4.6, 3.1, 1.5, 0.2, 1],
  237. [5.8, 2.7, 5.1, 1.9, 0],
  238. [5.5, 3.5, 1.3, 0.2, 1],
  239. [5.7, 2.6, 3.5, 1.0, 2],
  240. [5.0, 3.5, 1.3, 0.3, 1],
  241. [6.3, 2.5, 5.0, 1.9, 0],
  242. [6.2, 2.2, 4.5, 1.5, 2],
  243. [5.0, 3.4, 1.6, 0.4, 1],
  244. [5.7, 4.4, 1.5, 0.4, 1],
  245. [4.9, 2.4, 3.3, 1.0, 2],
  246. [4.4, 2.9, 1.4, 0.2, 1],
  247. [5.5, 2.4, 3.7, 1.0, 2],
  248. [5.6, 2.5, 3.9, 1.1, 2],
  249. [5.6, 2.8, 4.9, 2.0, 0],
  250. [4.8, 3.4, 1.6, 0.2, 1],
  251. [5.6, 3.0, 4.5, 1.5, 2],
  252. [6.0, 3.0, 4.8, 1.8, 0],
  253. [6.3, 3.3, 4.7, 1.6, 2],
  254. [4.8, 3.0, 1.4, 0.1, 1],
  255. [7.9, 3.8, 6.4, 2.0, 0],
  256. [4.9, 3.0, 1.4, 0.2, 1],
  257. [4.3, 3.0, 1.1, 0.1, 1],
  258. [6.8, 3.2, 5.9, 2.3, 0],
  259. [5.6, 2.7, 4.2, 1.3, 2],
  260. [5.2, 4.1, 1.5, 0.1, 1],
  261. [6.2, 2.9, 4.3, 1.3, 2],
  262. [6.5, 2.8, 4.6, 1.5, 2],
  263. [5.4, 3.9, 1.3, 0.4, 1],
  264. [5.8, 2.6, 4.0, 1.2, 2],
  265. [5.4, 3.7, 1.5, 0.2, 1],
  266. [4.5, 2.3, 1.3, 0.3, 1],
  267. [6.3, 3.4, 5.6, 2.4, 0],
  268. [6.2, 3.4, 5.4, 2.3, 0],
  269. [5.7, 2.5, 5.0, 2.0, 0],
  270. [5.8, 2.7, 3.9, 1.2, 2],
  271. [6.4, 2.7, 5.3, 1.9, 0],
  272. [5.1, 3.8, 1.6, 0.2, 1],
  273. [6.3, 2.5, 4.9, 1.5, 2],
  274. [7.7, 2.8, 6.7, 2.0, 0],
  275. [5.1, 3.5, 1.4, 0.3, 1],
  276. [6.8, 2.8, 4.8, 1.4, 2],
  277. [6.1, 3.0, 4.6, 1.4, 2],
  278. [5.5, 4.2, 1.4, 0.2, 1],
  279. [5.0, 2.0, 3.5, 1.0, 2],
  280. [7.7, 3.0, 6.1, 2.3, 0],
  281. [5.1, 2.5, 3.0, 1.1, 2],
  282. [5.9, 3.0, 5.1, 1.8, 0],
  283. [7.2, 3.2, 6.0, 1.8, 0],
  284. [4.9, 3.1, 1.5, 0.2, 1],
  285. [5.7, 3.0, 4.2, 1.2, 2],
  286. [6.1, 2.9, 4.7, 1.4, 2],
  287. [5.0, 3.2, 1.2, 0.2, 1],
  288. [4.4, 3.2, 1.3, 0.2, 1],
  289. [6.7, 3.1, 5.6, 2.4, 0],
  290. [4.6, 3.6, 1.0, 0.2, 1],
  291. [5.1, 3.4, 1.5, 0.2, 1],
  292. [5.2, 2.7, 3.9, 1.4, 2],
  293. [6.4, 3.1, 5.5, 1.8, 0],
  294. [7.4, 2.8, 6.1, 1.9, 0],
  295. [4.9, 3.1, 1.5, 0.1, 1],
  296. [5.0, 3.5, 1.6, 0.6, 1],
  297. [6.7, 3.1, 4.7, 1.5, 2],
  298. [6.4, 3.2, 5.3, 2.3, 0],
  299. [6.3, 2.7, 4.9, 1.8, 0],
  300. [5.8, 4.0, 1.2, 0.2, 1],
  301. [6.9, 3.1, 5.4, 2.1, 0],
  302. [5.9, 3.2, 4.8, 1.8, 2],
  303. [6.6, 2.9, 4.6, 1.3, 2],
  304. [6.1, 2.8, 4.0, 1.3, 2],
  305. [7.7, 2.6, 6.9, 2.3, 0],
  306. [5.5, 2.6, 4.4, 1.2, 2],
  307. [6.3, 2.9, 5.6, 1.8, 0],
  308. [7.2, 3.0, 5.8, 1.6, 0],
  309. [6.5, 3.0, 5.8, 2.2, 0],
  310. [5.4, 3.9, 1.7, 0.4, 1],
  311. [6.5, 3.2, 5.1, 2.0, 0],
  312. [5.9, 3.0, 4.2, 1.5, 2],
  313. [5.1, 3.7, 1.5, 0.4, 1],
  314. [5.7, 2.8, 4.5, 1.3, 2],
  315. [5.4, 3.4, 1.5, 0.4, 1],
  316. [4.6, 3.4, 1.4, 0.3, 1],
  317. [4.9, 3.6, 1.4, 0.1, 1],
  318. [6.7, 2.5, 5.8, 1.8, 0],
  319. [5.0, 3.6, 1.4, 0.2, 1],
  320. [6.7, 3.3, 5.7, 2.5, 0],
  321. [4.4, 3.0, 1.3, 0.2, 1],
  322. [6.0, 2.2, 5.0, 1.5, 0],
  323. [6.0, 2.2, 4.0, 1.0, 2],
  324. [5.0, 3.4, 1.5, 0.2, 1],
  325. [5.7, 2.8, 4.1, 1.3, 2],
  326. [5.5, 2.4, 3.8, 1.1, 2],
  327. [5.1, 3.8, 1.9, 0.4, 1],
  328. [6.9, 3.1, 5.1, 2.3, 0],
  329. [5.6, 2.9, 3.6, 1.3, 2],
  330. [6.1, 2.8, 4.7, 1.2, 2],
  331. [5.5, 2.5, 4.0, 1.3, 2],
  332. [5.5, 2.3, 4.0, 1.3, 2],
  333. [6.0, 2.9, 4.5, 1.5, 2],
  334. [5.1, 3.8, 1.5, 0.3, 1],
  335. [5.7, 3.8, 1.7, 0.3, 1],
  336. [6.7, 3.3, 5.7, 2.1, 0],
  337. [4.8, 3.1, 1.6, 0.2, 1],
  338. [5.4, 3.0, 4.5, 1.5, 2],
  339. [6.5, 3.0, 5.2, 2.0, 0],
  340. [6.8, 3.0, 5.5, 2.1, 0],
  341. [7.6, 3.0, 6.6, 2.1, 0],
  342. [5.0, 3.0, 1.6, 0.2, 1],
  343. [6.7, 3.0, 5.0, 1.7, 2],
  344. [4.8, 3.4, 1.9, 0.2, 1],
  345. [5.8, 2.8, 5.1, 2.4, 0],
  346. [5.0, 2.3, 3.3, 1.0, 2],
  347. [4.8, 3.0, 1.4, 0.3, 1],
  348. [5.2, 3.5, 1.5, 0.2, 1],
  349. [6.1, 2.6, 5.6, 1.4, 0],
  350. [5.8, 2.7, 4.1, 1.0, 2],
  351. [6.9, 3.2, 5.7, 2.3, 0],
  352. [6.4, 2.9, 4.3, 1.3, 2],
  353. [7.3, 2.9, 6.3, 1.8, 0],
  354. [6.3, 2.8, 5.1, 1.5, 0],
  355. [6.2, 2.8, 4.8, 1.8, 0],
  356. [6.7, 3.1, 4.4, 1.4, 2],
  357. [6.0, 2.7, 5.1, 1.6, 2],
  358. [6.5, 3.0, 5.5, 1.8, 0],
  359. [6.1, 3.0, 4.9, 1.8, 0],
  360. [5.6, 3.0, 4.1, 1.3, 2],
  361. [4.7, 3.2, 1.6, 0.2, 1],
  362. [6.6, 3.0, 4.4, 1.4, 2]]
  363.  
  364. if __name__ == '__main__':
  365. x = input() .split(', ')
  366. test_case = list(map(float, x[:-1])) + [int(x[-1])]
  367.  
  368. n1 = int(len(dataset)*0.3)
  369. n2 = int(len(dataset )*0.6)
  370.  
  371. dataset1 = dataset[:n1]
  372. dataset2 = dataset[n1:n2]
  373. dataset3 = dataset[n2:]
  374.  
  375. tree1 = build_tree(dataset1)
  376. tree2 = build_tree(dataset2)
  377. tree3 = build_tree(dataset3)
  378.  
  379. lista = []
  380.  
  381. lista.append(classify(test_case, tree1).popitem()[0])
  382. lista.append(classify(test_case, tree2).popitem()[0])
  383. lista.append(classify(test_case, tree3).popitem()[0])
  384.  
  385. glasovi = [0,0,0]
  386.  
  387. for i in range(0,3):
  388. for j in range(0,3):
  389. if lista[j] == i:
  390. glasovi[i]+=1
  391.  
  392.  
  393. recnik = dict()
  394. max = glasovi[0]
  395. index = 0
  396.  
  397. for i in range(0,3):
  398. recnik[i] = glasovi[i]
  399. if max < glasovi[i]:
  400. index = i
  401. max = glasovi[i]
  402.  
  403. print("Glasovi: "+str(recnik))
  404. print("Predvidena klasa: "+str(index))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement