Advertisement
Guest User

Untitled

a guest
Nov 15th, 2019
97
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 15.80 KB | None | 0 0
  1. from copy import deepcopy
  2. from math import log10, log
  3.  
  4. log2 = lambda x: log10(x) / log10(2)
  5.  
  6.  
  7. def unique_counts(rows):
  8. """Креирај броење на можни резултати (последната колона
  9. во секоја редица е класата)
  10.  
  11. :param rows: dataset
  12. :type rows: list
  13. :return: dictionary of possible classes as keys and count
  14. as values
  15. :rtype: dict
  16. """
  17. results = {}
  18. for row in rows:
  19. # Клацата е последната колона
  20. r = row[len(row) - 1]
  21. if r not in results:
  22. results[r] = 0
  23. results[r] += 1
  24. return results
  25.  
  26.  
  27. def gini_impurity(rows):
  28. """Probability that a randomly placed item will
  29. be in the wrong category
  30.  
  31. :param rows: dataset
  32. :type rows: list
  33. :return: Gini impurity
  34. :rtype: float
  35. """
  36. total = len(rows)
  37. counts = unique_counts(rows)
  38. imp = 0
  39. for k1 in counts:
  40. p1 = float(counts[k1]) / total
  41. for k2 in counts:
  42. if k1 == k2:
  43. continue
  44. p2 = float(counts[k2]) / total
  45. imp += p1 * p2
  46. return imp
  47.  
  48.  
  49. def entropy(rows):
  50. """Ентропијата е сума од p(x)log(p(x)) за сите
  51. можни резултати
  52.  
  53. :param rows: податочно множество
  54. :type rows: list
  55. :return: вредност за ентропијата
  56. :rtype: float
  57. """
  58. log2 = lambda x: log(x) / log(2)
  59. results = unique_counts(rows)
  60. # Пресметка на ентропијата
  61. ent = 0.0
  62. for r in results.keys():
  63. p = float(results[r]) / len(rows)
  64. ent = ent - p * log2(p)
  65. return ent
  66.  
  67.  
  68. class DecisionNode:
  69. def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
  70. """
  71. :param col: индексот на колоната (атрибутот) од тренинг множеството
  72. која се претставува со оваа инстанца т.е. со овој јазол
  73. :type col: int
  74. :param value: вредноста на јазолот според кој се дели дрвото
  75. :param results: резултати за тековната гранка, вредност (различна
  76. од None) само кај јазлите-листови во кои се донесува
  77. одлуката.
  78. :type results: dict
  79. :param tb: гранка која се дели од тековниот јазол кога вредноста е
  80. еднаква на value
  81. :type tb: DecisionNode
  82. :param fb: гранка која се дели од тековниот јазол кога вредноста е
  83. различна од value
  84. :type fb: DecisionNode
  85. """
  86. self.col = col
  87. self.value = value
  88. self.results = results
  89. self.tb = tb
  90. self.fb = fb
  91.  
  92.  
  93. def compare_numerical(row, column, value):
  94. """Споредба на вредноста од редицата на посакуваната колона со
  95. зададена нумеричка вредност
  96.  
  97. :param row: дадена редица во податочното множество
  98. :type row: list
  99. :param column: индекс на колоната (атрибутот) од тренирачкото множество
  100. :type column: int
  101. :param value: вредност на јазелот во согласност со кој се прави
  102. поделбата во дрвото
  103. :type value: int or float
  104. :return: True ако редицата >= value, инаку False
  105. :rtype: bool
  106. """
  107. return row[column] >= value
  108.  
  109.  
  110. def compare_nominal(row, column, value):
  111. """Споредба на вредноста од редицата на посакуваната колона со
  112. зададена номинална вредност
  113.  
  114. :param row: дадена редица во податочното множество
  115. :type row: list
  116. :param column: индекс на колоната (атрибутот) од тренирачкото множество
  117. :type column: int
  118. :param value: вредност на јазелот во согласност со кој се прави
  119. поделбата во дрвото
  120. :type value: str
  121. :return: True ако редицата == value, инаку False
  122. :rtype: bool
  123. """
  124. return row[column] == value
  125.  
  126.  
  127. def divide_set(rows, column, value):
  128. """Поделба на множеството според одредена колона. Може да се справи
  129. со нумерички или номинални вредности.
  130.  
  131. :param rows: тренирачко множество
  132. :type rows: list(list)
  133. :param column: индекс на колоната (атрибутот) од тренирачкото множество
  134. :type column: int
  135. :param value: вредност на јазелот во зависност со кој се прави поделбата
  136. во дрвото за конкретната гранка
  137. :type value: int or float or str
  138. :return: поделени подмножества
  139. :rtype: list, list
  140. """
  141. # Направи функција која ни кажува дали редицата е во
  142. # првата група (True) или втората група (False)
  143. if isinstance(value, int) or isinstance(value, float):
  144. # ако вредноста за споредба е од тип int или float
  145. split_function = compare_numerical
  146. else:
  147. # ако вредноста за споредба е од друг тип (string)
  148. split_function = compare_nominal
  149.  
  150. # Подели ги редиците во две подмножества и врати ги
  151. # за секој ред за кој split_function враќа True
  152. set1 = [row for row in rows if
  153. split_function(row, column, value)]
  154. # set1 = []
  155. # for row in rows:
  156. # if not split_function(row, column, value):
  157. # set1.append(row)
  158. # за секој ред за кој split_function враќа False
  159. set2 = [row for row in rows if
  160. not split_function(row, column, value)]
  161. return set1, set2
  162.  
  163.  
  164. def build_tree(rows, scoref=entropy):
  165. """Градење на дрво на одлука.
  166.  
  167. :param rows: тренирачко множество
  168. :type rows: list(list)
  169. :param scoref: функција за одбирање на најдобар атрибут во даден чекор
  170. :type scoref: function
  171. :return: коренот на изграденото дрво на одлука
  172. :rtype: DecisionNode object
  173. """
  174. if len(rows) == 0:
  175. return DecisionNode()
  176. current_score = scoref(rows)
  177.  
  178. # променливи со кои следиме кој критериум е најдобар
  179. best_gain = 0.0
  180. best_criteria = None
  181. best_sets = None
  182.  
  183. column_count = len(rows[0]) - 1
  184. for col in range(0, column_count):
  185. # за секоја колона (col се движи во интервалот од 0 до
  186. # column_count - 1)
  187. # Следниов циклус е за генерирање на речник од различни
  188. # вредности во оваа колона
  189. column_values = {}
  190. for row in rows:
  191. column_values[row[col]] = 1
  192. # за секоја редица се зема вредноста во оваа колона и се
  193. # поставува како клуч во column_values
  194. for value in column_values.keys():
  195. (set1, set2) = divide_set(rows, col, value)
  196.  
  197. # Информациона добивка
  198. p = float(len(set1)) / len(rows)
  199. gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  200. if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  201. best_gain = gain
  202. best_criteria = (col, value)
  203. best_sets = (set1, set2)
  204.  
  205. # Креирај ги подгранките
  206. if best_gain > 0:
  207. true_branch = build_tree(best_sets[0], scoref)
  208. false_branch = build_tree(best_sets[1], scoref)
  209. return DecisionNode(col=best_criteria[0], value=best_criteria[1],
  210. tb=true_branch, fb=false_branch)
  211. else:
  212. return DecisionNode(results=unique_counts(rows))
  213.  
  214.  
  215. def print_tree(tree, indent=''):
  216. """Принтање на дрво на одлука
  217.  
  218. :param tree: коренот на дрвото на одлучување
  219. :type tree: DecisionNode object
  220. :param indent:
  221. :return: None
  222. """
  223. # Дали е ова лист јазел?
  224. if tree.results:
  225. print(str(tree.results))
  226. else:
  227. # Се печати условот
  228. print(str(tree.col) + ':' + str(tree.value) + '? ')
  229. # Се печатат True гранките, па False гранките
  230. print(indent + 'T->', end='')
  231. print_tree(tree.tb, indent + ' ')
  232. print(indent + 'F->', end='')
  233. print_tree(tree.fb, indent + ' ')
  234.  
  235.  
  236. def classify(observation, tree):
  237. """Класификација на нов податочен примерок со изградено дрво на одлука
  238.  
  239. :param observation: еден ред од податочното множество за предвидување
  240. :type observation: list
  241. :param tree: коренот на дрвото на одлучување
  242. :type tree: DecisionNode object
  243. :return: речник со класите како клуч и бројот на појавување во листот на дрвото
  244. за класификација како вредност во речникот
  245. :rtype: dict
  246. """
  247. if tree.results:
  248. return tree.results
  249. else:
  250. value = observation[tree.col]
  251. if isinstance(value, int) or isinstance(value, float):
  252. compare = compare_numerical
  253. else:
  254. compare = compare_nominal
  255.  
  256. if compare(observation, tree.col, tree.value):
  257. branch = tree.tb
  258. else:
  259. branch = tree.fb
  260.  
  261. return classify(observation, branch)
  262.  
  263.  
  264. dataset = [
  265. [6.3, 2.9, 5.6, 1.8, 0],
  266. [6.5, 3.0, 5.8, 2.2, 0],
  267. [7.6, 3.0, 6.6, 2.1, 0],
  268. [4.9, 2.5, 4.5, 1.7, 0],
  269. [7.3, 2.9, 6.3, 1.8, 0],
  270. [6.7, 2.5, 5.8, 1.8, 0],
  271. [7.2, 3.6, 6.1, 2.5, 0],
  272. [6.5, 3.2, 5.1, 2.0, 0],
  273. [6.4, 2.7, 5.3, 1.9, 0],
  274. [6.8, 3.0, 5.5, 2.1, 0],
  275. [5.7, 2.5, 5.0, 2.0, 0],
  276. [5.8, 2.8, 5.1, 2.4, 0],
  277. [6.4, 3.2, 5.3, 2.3, 0],
  278. [6.5, 3.0, 5.5, 1.8, 0],
  279. [7.7, 3.8, 6.7, 2.2, 0],
  280. [7.7, 2.6, 6.9, 2.3, 0],
  281. [6.0, 2.2, 5.0, 1.5, 0],
  282. [6.9, 3.2, 5.7, 2.3, 0],
  283. [5.6, 2.8, 4.9, 2.0, 0],
  284. [7.7, 2.8, 6.7, 2.0, 0],
  285. [6.3, 2.7, 4.9, 1.8, 0],
  286. [6.7, 3.3, 5.7, 2.1, 0],
  287. [7.2, 3.2, 6.0, 1.8, 0],
  288. [6.2, 2.8, 4.8, 1.8, 0],
  289. [6.1, 3.0, 4.9, 1.8, 0],
  290. [6.4, 2.8, 5.6, 2.1, 0],
  291. [7.2, 3.0, 5.8, 1.6, 0],
  292. [7.4, 2.8, 6.1, 1.9, 0],
  293. [7.9, 3.8, 6.4, 2.0, 0],
  294. [6.4, 2.8, 5.6, 2.2, 0],
  295. [6.3, 2.8, 5.1, 1.5, 0],
  296. [6.1, 2.6, 5.6, 1.4, 0],
  297. [7.7, 3.0, 6.1, 2.3, 0],
  298. [6.3, 3.4, 5.6, 2.4, 0],
  299. [5.1, 3.5, 1.4, 0.2, 1],
  300. [4.9, 3.0, 1.4, 0.2, 1],
  301. [4.7, 3.2, 1.3, 0.2, 1],
  302. [4.6, 3.1, 1.5, 0.2, 1],
  303. [5.0, 3.6, 1.4, 0.2, 1],
  304. [5.4, 3.9, 1.7, 0.4, 1],
  305. [4.6, 3.4, 1.4, 0.3, 1],
  306. [5.0, 3.4, 1.5, 0.2, 1],
  307. [4.4, 2.9, 1.4, 0.2, 1],
  308. [4.9, 3.1, 1.5, 0.1, 1],
  309. [5.4, 3.7, 1.5, 0.2, 1],
  310. [4.8, 3.4, 1.6, 0.2, 1],
  311. [4.8, 3.0, 1.4, 0.1, 1],
  312. [4.3, 3.0, 1.1, 0.1, 1],
  313. [5.8, 4.0, 1.2, 0.2, 1],
  314. [5.7, 4.4, 1.5, 0.4, 1],
  315. [5.4, 3.9, 1.3, 0.4, 1],
  316. [5.1, 3.5, 1.4, 0.3, 1],
  317. [5.7, 3.8, 1.7, 0.3, 1],
  318. [5.1, 3.8, 1.5, 0.3, 1],
  319. [5.4, 3.4, 1.7, 0.2, 1],
  320. [5.1, 3.7, 1.5, 0.4, 1],
  321. [4.6, 3.6, 1.0, 0.2, 1],
  322. [5.1, 3.3, 1.7, 0.5, 1],
  323. [4.8, 3.4, 1.9, 0.2, 1],
  324. [5.0, 3.0, 1.6, 0.2, 1],
  325. [5.0, 3.4, 1.6, 0.4, 1],
  326. [5.2, 3.5, 1.5, 0.2, 1],
  327. [5.2, 3.4, 1.4, 0.2, 1],
  328. [5.5, 2.3, 4.0, 1.3, 2],
  329. [6.5, 2.8, 4.6, 1.5, 2],
  330. [5.7, 2.8, 4.5, 1.3, 2],
  331. [6.3, 3.3, 4.7, 1.6, 2],
  332. [4.9, 2.4, 3.3, 1.0, 2],
  333. [6.6, 2.9, 4.6, 1.3, 2],
  334. [5.2, 2.7, 3.9, 1.4, 2],
  335. [5.0, 2.0, 3.5, 1.0, 2],
  336. [5.9, 3.0, 4.2, 1.5, 2],
  337. [6.0, 2.2, 4.0, 1.0, 2],
  338. [6.1, 2.9, 4.7, 1.4, 2],
  339. [5.6, 2.9, 3.6, 1.3, 2],
  340. [6.7, 3.1, 4.4, 1.4, 2],
  341. [5.6, 3.0, 4.5, 1.5, 2],
  342. [5.8, 2.7, 4.1, 1.0, 2],
  343. [6.2, 2.2, 4.5, 1.5, 2],
  344. [5.6, 2.5, 3.9, 1.1, 2],
  345. [5.9, 3.2, 4.8, 1.8, 2],
  346. [6.1, 2.8, 4.0, 1.3, 2],
  347. [6.3, 2.5, 4.9, 1.5, 2],
  348. [6.1, 2.8, 4.7, 1.2, 2],
  349. [6.4, 2.9, 4.3, 1.3, 2],
  350. [6.6, 3.0, 4.4, 1.4, 2],
  351. [6.8, 2.8, 4.8, 1.4, 2],
  352. [6.7, 3.0, 5.0, 1.7, 2],
  353. [6.0, 2.9, 4.5, 1.5, 2],
  354. [5.7, 2.6, 3.5, 1.0, 2],
  355. [5.5, 2.4, 3.8, 1.1, 2],
  356. [5.4, 3.0, 4.5, 1.5, 2],
  357. [6.0, 3.4, 4.5, 1.6, 2],
  358. [6.7, 3.1, 4.7, 1.5, 2],
  359. [6.3, 2.3, 4.4, 1.3, 2],
  360. [5.6, 3.0, 4.1, 1.3, 2],
  361. [5.5, 2.5, 4.0, 1.3, 2],
  362. [5.5, 2.6, 4.4, 1.2, 2],
  363. [6.1, 3.0, 4.6, 1.4, 2],
  364. [5.8, 2.6, 4.0, 1.2, 2],
  365. [5.0, 2.3, 3.3, 1.0, 2],
  366. [5.6, 2.7, 4.2, 1.3, 2],
  367. [5.7, 3.0, 4.2, 1.2, 2],
  368. [5.7, 2.9, 4.2, 1.3, 2],
  369. [6.2, 2.9, 4.3, 1.3, 2],
  370. [5.1, 2.5, 3.0, 1.1, 2],
  371. [5.7, 2.8, 4.1, 1.3, 2],
  372. [6.4, 3.1, 5.5, 1.8, 0],
  373. [6.0, 3.0, 4.8, 1.8, 0],
  374. [6.9, 3.1, 5.4, 2.1, 0],
  375. [6.8, 3.2, 5.9, 2.3, 0],
  376. [6.7, 3.3, 5.7, 2.5, 0],
  377. [6.7, 3.0, 5.2, 2.3, 0],
  378. [6.3, 2.5, 5.0, 1.9, 0],
  379. [6.5, 3.0, 5.2, 2.0, 0],
  380. [6.2, 3.4, 5.4, 2.3, 0],
  381. [4.7, 3.2, 1.6, 0.2, 1],
  382. [4.8, 3.1, 1.6, 0.2, 1],
  383. [5.4, 3.4, 1.5, 0.4, 1],
  384. [5.2, 4.1, 1.5, 0.1, 1],
  385. [5.5, 4.2, 1.4, 0.2, 1],
  386. [4.9, 3.1, 1.5, 0.2, 1],
  387. [5.0, 3.2, 1.2, 0.2, 1],
  388. [5.5, 3.5, 1.3, 0.2, 1],
  389. [4.9, 3.6, 1.4, 0.1, 1],
  390. [4.4, 3.0, 1.3, 0.2, 1],
  391. [5.1, 3.4, 1.5, 0.2, 1],
  392. [5.0, 3.5, 1.3, 0.3, 1],
  393. [4.5, 2.3, 1.3, 0.3, 1],
  394. [4.4, 3.2, 1.3, 0.2, 1],
  395. [5.0, 3.5, 1.6, 0.6, 1],
  396. [5.9, 3.0, 5.1, 1.8, 0],
  397. [5.1, 3.8, 1.9, 0.4, 1],
  398. [4.8, 3.0, 1.4, 0.3, 1],
  399. [5.1, 3.8, 1.6, 0.2, 1],
  400. [5.5, 2.4, 3.7, 1.0, 2],
  401. [5.8, 2.7, 3.9, 1.2, 2],
  402. [6.0, 2.7, 5.1, 1.6, 2],
  403. [6.7, 3.1, 5.6, 2.4, 0],
  404. [6.9, 3.1, 5.1, 2.3, 0],
  405. [5.8, 2.7, 5.1, 1.9, 0],
  406. ]
  407.  
  408. if __name__ == "__main__":
  409. column_ind = int(input())
  410. procent = len(dataset) * 80 / 100
  411.  
  412. data1 = dataset[:int(procent)]
  413. data2 = deepcopy(data1)
  414.  
  415. for row in data2:
  416. del row[column_ind]
  417.  
  418.  
  419. t1 = build_tree(data1)
  420. t2 = build_tree(data2)
  421.  
  422. dataTest = dataset[int(procent):]
  423. counter1 = 0
  424. counter2 = 0
  425.  
  426. for row in dataTest:
  427. dolz = len(row)
  428. l = row[:dolz - 1]
  429. drvo1 = classify(l, t1)
  430.  
  431. for key in drvo1.keys():
  432. if key == row[len(row)-1]:
  433. counter1 += 1
  434.  
  435.  
  436. data3 = deepcopy(dataset[int(procent):])
  437. for row in data3:
  438. del row[column_ind]
  439. for row in data3:
  440. dolz = len(row)
  441. l = row[:dolz - 1]
  442. drvo2 = classify(l, t2)
  443.  
  444. for key in drvo2.keys():
  445. if key == row[len(row)-1]:
  446. counter2 += 1
  447.  
  448.  
  449.  
  450. print("Tochnost so prvoto drvo na odluka:", counter1 / len(dataTest))
  451. print("Tochnost so vtoroto drvo na odluka:", counter2 / len(dataTest))
  452.  
  453. """for key in drvo2.keys():
  454. if key == row[len(row)-1]:
  455. counter2+=1
  456. print("tocno e")
  457.  
  458.  
  459. # Vashiot kod tuka
  460. """
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement