Advertisement
glavinova

[СНЗ] Дрва за одлучување

Jul 5th, 2020
1,065
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 20.48 KB | None | 0 0
  1. """Дрва за одлучување Problem 2 (3 / 3)
  2. Дадено е податочно множество од риби кое е опишано со следните атрибути:
  3.  
  4.    0  Weight      Weight of the fish (in grams)
  5.    1  Length1     Length from the nose to the beginning of the tail (in cm)
  6.    2  Length2     Length from the nose to the notch of the tail (in cm)
  7.    3  Length3     Length from the nose to the end of the tail (in cm)
  8.    4  Height%     Maximal height as % of Length3
  9.    5  Width%      Maximal width as % of Length3
  10.    6  Class       Fish Species
  11. Класата (видот на рибата) е дадена во последната колона.
  12.  
  13. Да се направи модел за класификација за даденото податочно множество. За тренинг да се земат само првите 40 примероци од класите Roach и Pike (Првите 40 од Roach и првите 40 од Pike). Значи треба да се направи бинарен класификатор, при што секоја класа освен _Roach_ и _Pike_ се игнорира при тренирањето на множеството. Притоа земањето на првите 40 примероци да се направи во програмата, а не со рачно копирање! Ако во множеството има помалку од 40 примероци од дадена класа, се земаат онолку колку што има (т.е. сите што ги има).
  14.  
  15. Да се класифицира елементот даден на влез и да се испечати предвидувањето."""
  16.  
  17.  
  18. from math import log
  19. data=[[180.0, 23.6, 25.2, 27.9, 25.4, 14.0, 'Roach'],
  20.  [12.2, 11.5, 12.2, 13.4, 15.6, 10.4, 'Smelt'],
  21.  [135.0, 20.0, 22.0, 23.5, 25.0, 15.0, 'Perch'],
  22.  [1600.0, 56.0, 60.0, 64.0, 15.0, 9.6, 'Pike'],
  23.  [120.0, 20.0, 22.0, 23.5, 26.0, 14.5, 'Perch'],
  24.  [273.0, 23.0, 25.0, 28.0, 39.6, 14.8, 'Silver Bream'],
  25.  [320.0, 27.8, 30.0, 31.6, 24.1, 15.1, 'Perch'],
  26.  [160.0, 21.1, 22.5, 25.0, 25.6, 15.2, 'Roach'],
  27.  [700.0, 30.4, 33.0, 38.3, 38.8, 13.8, 'Bream'],
  28.  [500.0, 29.5, 32.0, 37.3, 37.3, 13.6, 'Bream'],
  29.  [290.0, 24.0, 26.3, 31.2, 40.0, 13.8, 'Bream'],
  30.  [650.0, 31.0, 33.5, 38.7, 37.4, 14.8, 'Bream'],
  31.  [500.0, 26.8, 29.7, 34.5, 41.1, 15.3, 'Bream'],
  32.  [260.0, 25.4, 27.5, 28.9, 24.8, 15.0, 'Perch'],
  33.  [80.0, 17.2, 19.0, 20.2, 27.9, 15.1, 'Perch'],
  34.  [850.0, 32.8, 36.0, 41.6, 40.6, 14.9, 'Bream'],
  35.  [345.0, 36.0, 38.5, 41.0, 15.6, 9.7, 'Pike'],
  36.  [567.0, 43.2, 46.0, 48.7, 16.0, 10.0, 'Pike'],
  37.  [55.0, 13.5, 14.7, 16.5, 41.5, 14.1, 'Silver Bream'],
  38.  [78.0, 16.8, 18.7, 19.4, 26.8, 16.1, 'Perch'],
  39.  [950.0, 38.0, 41.0, 46.5, 37.9, 13.7, 'Bream'],
  40.  [306.0, 25.6, 28.0, 30.8, 28.5, 15.2, 'Whitewish'],
  41.  [6.7, 9.3, 9.8, 10.8, 16.1, 9.7, 'Smelt'],
  42.  [714.0, 32.7, 36.0, 41.5, 39.8, 14.1, 'Bream'],
  43.  [197.0, 23.5, 25.6, 27.0, 24.3, 15.7, 'Perch'],
  44.  [1000.0, 41.1, 44.0, 46.6, 26.8, 16.3, 'Perch'],
  45.  [685.0, 34.0, 36.5, 39.0, 27.9, 17.6, 'Perch'],
  46.  [169.0, 22.0, 24.0, 27.2, 27.7, 14.1, 'Roach'],
  47.  [125.0, 19.0, 21.0, 22.5, 25.3, 16.3, 'Perch'],
  48.  [1000.0, 33.5, 37.0, 42.6, 44.5, 15.5, 'Bream'],
  49.  [900.0, 36.5, 39.0, 41.4, 26.9, 18.1, 'Perch'],
  50.  [19.7, 13.2, 14.3, 15.2, 18.9, 13.6, 'Smelt'],
  51.  [150.0, 20.4, 22.0, 24.7, 23.5, 15.2, 'Roach'],
  52.  [120.0, 17.5, 19.0, 21.3, 39.4, 13.7, 'Silver Bream'],
  53.  [140.0, 19.0, 20.7, 23.2, 36.8, 14.2, 'Silver Bream'],
  54.  [290.0, 24.0, 26.0, 29.2, 30.4, 15.4, 'Roach'],
  55.  [725.0, 31.8, 35.0, 40.9, 40.0, 14.8, 'Bream'],
  56.  [1000.0, 40.2, 43.5, 46.0, 27.4, 17.7, 'Perch'],
  57.  [188.0, 22.6, 24.6, 26.2, 25.7, 15.9, 'Perch'],
  58.  [242.0, 23.2, 25.4, 30.0, 38.4, 13.4, 'Bream'],
  59.  [475.0, 28.4, 31.0, 36.2, 39.4, 14.1, 'Bream'],
  60.  [700.0, 30.4, 33.0, 38.5, 38.8, 13.5, 'Bream'],
  61.  [120.0, 18.6, 20.0, 22.2, 28.0, 16.1, 'Roach'],
  62.  [820.0, 36.6, 39.0, 41.3, 30.1, 17.8, 'Perch'],
  63.  [540.0, 28.5, 31.0, 34.0, 31.6, 19.3, 'Whitewish'],
  64.  [150.0, 20.5, 22.5, 24.0, 28.3, 15.1, 'Perch'],
  65.  [161.0, 22.0, 23.4, 26.7, 25.9, 13.6, 'Roach'],
  66.  [60.0, 14.3, 15.5, 17.4, 37.8, 13.3, 'Silver Bream'],
  67.  [840.0, 32.5, 35.0, 37.3, 30.8, 20.9, 'Perch'],
  68.  [300.0, 24.0, 26.0, 29.0, 39.2, 14.6, 'Silver Bream'],
  69.  [300.0, 25.2, 27.3, 28.7, 29.0, 17.9, 'Perch'],
  70.  [180.0, 23.0, 25.0, 26.5, 24.3, 13.9, 'Perch'],
  71.  [85.0, 18.2, 20.0, 21.0, 24.2, 13.2, 'Perch'],
  72.  [130.0, 20.5, 22.5, 24.0, 24.4, 15.1, 'Perch'],
  73.  [900.0, 37.0, 40.0, 42.5, 27.6, 17.0, 'Perch'],
  74.  [9.9, 11.3, 11.8, 13.1, 16.9, 8.9, 'Smelt'],
  75.  [620.0, 31.5, 34.5, 39.7, 39.1, 13.3, 'Bream'],
  76.  [720.0, 32.0, 35.0, 40.6, 40.3, 15.0, 'Bream'],
  77.  [270.0, 23.6, 26.0, 28.7, 29.2, 14.8, 'Whitewish'],
  78.  [40.0, 13.8, 15.0, 16.0, 23.9, 15.2, 'Perch'],
  79.  [5.9, 7.5, 8.4, 8.8, 24.0, 16.0, 'Perch'],
  80.  [115.0, 19.0, 21.0, 22.5, 26.3, 14.7, 'Perch'],
  81.  [110.0, 20.0, 22.0, 23.5, 23.5, 17.0, 'Perch'],
  82.  [300.0, 26.9, 28.7, 30.1, 25.2, 15.4, 'Perch'],
  83.  [363.0, 26.3, 29.0, 33.5, 38.0, 13.3, 'Bream'],
  84.  [690.0, 34.6, 37.0, 39.3, 26.9, 16.2, 'Perch'],
  85.  [820.0, 37.1, 40.0, 42.5, 26.2, 15.6, 'Perch'],
  86.  [19.9, 13.8, 15.0, 16.2, 18.1, 11.6, 'Smelt'],
  87.  [40.0, 12.9, 14.1, 16.2, 25.6, 14.0, 'Roach'],
  88.  [390.0, 27.6, 30.0, 35.0, 36.2, 13.4, 'Bream'],
  89.  [1250.0, 52.0, 56.0, 59.7, 17.9, 11.7, 'Pike'],
  90.  [87.0, 18.2, 19.8, 22.2, 25.3, 14.3, 'Roach'],
  91.  [9.8, 10.7, 11.2, 12.4, 16.8, 10.3, 'Smelt'],
  92.  [13.4, 11.7, 12.4, 13.5, 18.0, 9.4, 'Smelt'],
  93.  [975.0, 37.4, 41.0, 45.9, 40.6, 14.7, 'Bream'],
  94.  [1100.0, 39.0, 42.0, 44.6, 28.7, 15.4, 'Perch'],
  95.  [130.0, 20.0, 22.0, 23.5, 26.0, 15.0, 'Perch'],
  96.  [450.0, 27.6, 30.0, 35.1, 39.9, 13.8, 'Bream'],
  97.  [200.0, 30.0, 32.3, 34.8, 16.0, 9.7, 'Pike'],
  98.  [340.0, 23.9, 26.5, 31.1, 39.8, 15.1, 'Bream'],
  99.  [700.0, 34.0, 36.0, 38.3, 27.7, 17.6, 'Perch'],
  100.  [170.0, 21.5, 23.5, 25.0, 25.1, 14.9, 'Perch'],
  101.  [500.0, 29.1, 31.5, 36.4, 37.8, 12.0, 'Bream'],
  102.  [150.0, 18.4, 20.0, 22.4, 39.7, 14.7, 'Silver Bream'],
  103.  [145.0, 20.7, 22.7, 24.2, 24.6, 15.0, 'Perch'],
  104.  [85.0, 17.8, 19.6, 20.8, 24.7, 14.6, 'Perch'],
  105.  [600.0, 29.4, 32.0, 37.2, 40.2, 13.9, 'Bream'],
  106.  [300.0, 34.8, 37.3, 39.8, 15.8, 10.1, 'Pike'],
  107.  [456.0, 40.0, 42.5, 45.5, 16.0, 9.5, 'Pike'],
  108.  [540.0, 40.1, 43.0, 45.8, 17.0, 11.2, 'Pike'],
  109.  [12.2, 12.1, 13.0, 13.8, 16.5, 9.1, 'Smelt'],
  110.  [100.0, 16.2, 18.0, 19.2, 27.2, 17.3, 'Perch'],
  111.  [300.0, 32.7, 35.0, 38.8, 15.3, 11.3, 'Pike'],
  112.  [700.0, 31.9, 35.0, 40.5, 40.1, 13.8, 'Bream'],
  113.  [610.0, 30.9, 33.5, 38.6, 40.5, 13.3, 'Bream'],
  114.  [700.0, 34.5, 37.0, 39.4, 27.5, 15.9, 'Perch'],
  115.  [70.0, 15.7, 17.4, 18.5, 24.8, 15.9, 'Perch'],
  116.  [955.0, 35.0, 38.5, 44.0, 41.1, 14.3, 'Bream'],
  117.  [514.0, 30.5, 32.8, 34.0, 29.5, 17.7, 'Perch'],
  118.  [51.5, 15.0, 16.2, 17.2, 26.7, 15.3, 'Perch'],
  119.  [272.0, 25.0, 27.0, 30.6, 28.0, 15.6, 'Roach'],
  120.  [500.0, 28.5, 30.7, 36.2, 39.3, 13.7, 'Bream'],
  121.  [9.8, 11.4, 12.0, 13.2, 16.7, 8.7, 'Smelt'],
  122.  [510.0, 40.0, 42.5, 45.5, 15.0, 9.8, 'Pike'],
  123.  [925.0, 36.2, 39.5, 45.3, 41.4, 14.9, 'Bream'],
  124.  [1015.0, 37.0, 40.0, 42.4, 29.2, 17.6, 'Perch'],
  125.  [1550.0, 56.0, 60.0, 64.0, 15.0, 9.6, 'Pike'],
  126.  [1000.0, 37.3, 40.0, 43.5, 28.4, 15.0, 'Whitewish'],
  127.  [920.0, 35.0, 38.5, 44.1, 40.9, 14.3, 'Bream'],
  128.  [140.0, 21.0, 22.5, 25.0, 26.2, 13.3, 'Roach'],
  129.  [218.0, 25.0, 26.5, 28.0, 25.6, 14.8, 'Perch'],
  130.  [9.7, 10.4, 11.0, 12.0, 18.3, 11.5, 'Smelt'],
  131.  [69.0, 16.5, 18.2, 20.3, 26.1, 13.9, 'Roach'],
  132.  [110.0, 19.0, 21.0, 22.5, 25.3, 15.8, 'Perch'],
  133.  [150.0, 21.0, 23.0, 24.5, 21.3, 14.8, 'Perch'],
  134.  [160.0, 20.5, 22.5, 25.3, 27.8, 15.1, 'Roach'],
  135.  [7.0, 10.1, 10.6, 11.6, 14.9, 9.9, 'Smelt'],
  136.  [78.0, 17.5, 18.8, 21.2, 26.3, 13.7, 'Roach'],
  137.  [450.0, 26.8, 29.7, 34.7, 39.2, 14.2, 'Bream'],
  138.  [556.0, 32.0, 34.5, 36.5, 28.1, 17.5, 'Perch'],
  139.  [1650.0, 59.0, 63.4, 68.0, 15.9, 11.0, 'Pike'],
  140.  [110.0, 19.1, 20.8, 23.1, 26.7, 14.7, 'Roach'],
  141.  [685.0, 31.4, 34.0, 39.2, 40.8, 13.7, 'Bream'],
  142.  [200.0, 22.1, 23.5, 26.8, 27.6, 15.4, 'Roach'],
  143.  [770.0, 44.8, 48.0, 51.2, 15.0, 10.5, 'Pike'],
  144.  [7.5, 10.0, 10.5, 11.6, 17.0, 10.0, 'Smelt'],
  145.  [8.7, 10.8, 11.3, 12.6, 15.7, 10.2, 'Smelt'],
  146.  [500.0, 42.0, 45.0, 48.0, 14.5, 10.2, 'Pike'],
  147.  [170.0, 19.0, 20.7, 23.2, 40.5, 14.7, 'Silver Bream'],
  148.  [120.0, 20.0, 22.0, 23.5, 24.0, 15.0, 'Perch'],
  149.  [145.0, 19.8, 21.5, 24.1, 40.4, 13.1, 'Silver Bream'],
  150.  [130.0, 19.3, 21.3, 22.8, 28.0, 15.5, 'Perch'],
  151.  [850.0, 36.9, 40.0, 42.3, 28.2, 16.8, 'Perch'],
  152.  [265.0, 25.4, 27.5, 28.9, 24.4, 15.0, 'Perch'],
  153.  [0.0, 19.0, 20.5, 22.8, 28.4, 14.7, 'Roach'],
  154.  [680.0, 31.8, 35.0, 40.6, 38.1, 15.1, 'Bream'],
  155.  [90.0, 16.3, 17.7, 19.8, 37.4, 13.5, 'Silver Bream'],
  156.  [575.0, 31.3, 34.0, 39.5, 38.3, 14.1, 'Bream'],
  157.  [390.0, 29.5, 31.7, 35.0, 27.1, 15.3, 'Roach'],
  158.  [225.0, 22.0, 24.0, 25.5, 28.6, 14.6, 'Perch'],
  159.  [10.0, 11.3, 11.8, 13.1, 16.9, 9.8, 'Smelt'],
  160.  [1000.0, 39.8, 43.0, 45.2, 26.4, 16.1, 'Perch'],
  161.  [500.0, 28.7, 31.0, 36.2, 39.7, 13.3, 'Bream'],
  162.  [120.0, 19.4, 21.0, 23.7, 25.8, 13.9, 'Roach'],
  163.  [430.0, 35.5, 38.0, 40.5, 18.0, 11.3, 'Pike'],
  164.  [200.0, 21.2, 23.0, 25.8, 40.1, 14.2, 'Silver Bream'],
  165.  [250.0, 25.9, 28.0, 29.4, 26.6, 14.3, 'Perch'],
  166.  [800.0, 33.7, 36.4, 39.6, 29.7, 16.6, 'Whitewish'],
  167.  [32.0, 12.5, 13.7, 14.7, 24.0, 13.6, 'Perch'],
  168.  [430.0, 26.5, 29.0, 34.0, 36.6, 15.1, 'Bream'],
  169.  [145.0, 20.5, 22.0, 24.3, 27.3, 14.6, 'Roach'],
  170.  [950.0, 48.3, 51.7, 55.1, 16.2, 11.2, 'Pike'],
  171.  [300.0, 31.7, 34.0, 37.8, 15.1, 11.0, 'Pike'],
  172.  [250.0, 25.4, 27.5, 28.9, 25.2, 15.8, 'Perch'],
  173.  [650.0, 36.5, 39.0, 41.4, 26.9, 14.5, 'Perch'],
  174.  [270.0, 24.1, 26.5, 29.3, 27.8, 14.5, 'Whitewish'],
  175.  [600.0, 29.4, 32.0, 37.2, 41.5, 15.0, 'Bream'],
  176.  [145.0, 22.0, 24.0, 25.5, 25.0, 15.0, 'Perch'],
  177.  [1100.0, 40.1, 43.0, 45.5, 27.5, 16.3, 'Perch']]
  178.  
  179. from math import log
  180.  
  181.  
  182. def unique_counts(rows):
  183.     """Креирај броење на можни резултати (последната колона
  184.    во секоја редица е класата)
  185.  
  186.    :param rows: dataset
  187.    :type rows: list
  188.    :return: dictionary of possible classes as keys and count
  189.             as values
  190.    :rtype: dict
  191.    """
  192.     results = {}
  193.     for row in rows:
  194.         # Клацата е последната колона
  195.         r = row[len(row) - 1]
  196.         if r not in results:
  197.             results[r] = 0
  198.         results[r] += 1
  199.     return results
  200.  
  201.  
  202. def gini_impurity(rows):
  203.     """Probability that a randomly placed item will
  204.    be in the wrong category
  205.  
  206.    :param rows: dataset
  207.    :type rows: list
  208.    :return: Gini impurity
  209.    :rtype: float
  210.    """
  211.     total = len(rows)
  212.     counts = unique_counts(rows)
  213.     imp = 0
  214.     for k1 in counts:
  215.         p1 = float(counts[k1]) / total
  216.         for k2 in counts:
  217.             if k1 == k2:
  218.                 continue
  219.             p2 = float(counts[k2]) / total
  220.             imp += p1 * p2
  221.     return imp
  222.  
  223.  
  224. def entropy(rows):
  225.     """Ентропијата е сума од p(x)log(p(x)) за сите
  226.    можни резултати
  227.  
  228.    :param rows: податочно множество
  229.    :type rows: list
  230.    :return: вредност за ентропијата
  231.    :rtype: float
  232.    """
  233.     log2 = lambda x: log(x) / log(2)
  234.     results = unique_counts(rows)
  235.     # Пресметка на ентропијата
  236.     ent = 0.0
  237.     for r in results.keys():
  238.         p = float(results[r]) / len(rows)
  239.         ent = ent - p * log2(p)
  240.     return ent
  241.  
  242.  
  243. class DecisionNode:
  244.     def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
  245.         """
  246.        :param col: индексот на колоната (атрибутот) од тренинг множеството
  247.                    која се претставува со оваа инстанца т.е. со овој јазол
  248.        :type col: int
  249.        :param value: вредноста на јазолот според кој се дели дрвото
  250.        :param results: резултати за тековната гранка, вредност (различна
  251.                        од None) само кај јазлите-листови во кои се донесува
  252.                        одлуката.
  253.        :type results: dict
  254.        :param tb: гранка која се дели од тековниот јазол кога вредноста е
  255.                   еднаква на value
  256.        :type tb: DecisionNode
  257.        :param fb: гранка која се дели од тековниот јазол кога вредноста е
  258.                   различна од value
  259.        :type fb: DecisionNode
  260.        """
  261.         self.col = col
  262.         self.value = value
  263.         self.results = results
  264.         self.tb = tb
  265.         self.fb = fb
  266.  
  267.  
  268. def compare_numerical(row, column, value):
  269.     """Споредба на вредноста од редицата на посакуваната колона со
  270.    зададена нумеричка вредност
  271.  
  272.    :param row: дадена редица во податочното множество
  273.    :type row: list
  274.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  275.    :type column: int
  276.    :param value: вредност на јазелот во согласност со кој се прави
  277.                  поделбата во дрвото
  278.    :type value: int or float
  279.    :return: True ако редицата >= value, инаку False
  280.    :rtype: bool
  281.    """
  282.     return row[column] >= value
  283.  
  284.  
  285. def compare_nominal(row, column, value):
  286.     """Споредба на вредноста од редицата на посакуваната колона со
  287.    зададена номинална вредност
  288.  
  289.    :param row: дадена редица во податочното множество
  290.    :type row: list
  291.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  292.    :type column: int
  293.    :param value: вредност на јазелот во согласност со кој се прави
  294.                  поделбата во дрвото
  295.    :type value: str
  296.    :return: True ако редицата == value, инаку False
  297.    :rtype: bool
  298.    """
  299.     return row[column] == value
  300.  
  301.  
  302. def divide_set(rows, column, value):
  303.     """Поделба на множеството според одредена колона. Може да се справи
  304.    со нумерички или номинални вредности.
  305.  
  306.    :param rows: тренирачко множество
  307.    :type rows: list(list)
  308.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  309.    :type column: int
  310.    :param value: вредност на јазелот во зависност со кој се прави поделбата
  311.                  во дрвото за конкретната гранка
  312.    :type value: int or float or str
  313.    :return: поделени подмножества
  314.    :rtype: list, list
  315.    """
  316.     # Направи функција која ни кажува дали редицата е во
  317.     # првата група (True) или втората група (False)
  318.     if isinstance(value, int) or isinstance(value, float):
  319.         # ако вредноста за споредба е од тип int или float
  320.         split_function = compare_numerical
  321.     else:
  322.         # ако вредноста за споредба е од друг тип (string)
  323.         split_function = compare_nominal
  324.  
  325.     # Подели ги редиците во две подмножества и врати ги
  326.     # за секој ред за кој split_function враќа True
  327.     set1 = [row for row in rows if
  328.             split_function(row, column, value)]
  329.     # set1 = []
  330.     # for row in rows:
  331.     #     if not split_function(row, column, value):
  332.     #         set1.append(row)
  333.     # за секој ред за кој split_function враќа False
  334.     set2 = [row for row in rows if
  335.             not split_function(row, column, value)]
  336.     return set1, set2
  337.  
  338.  
  339. def build_tree(rows, scoref=entropy):
  340.     """Градење на дрво на одлука.
  341.  
  342.    :param rows: тренирачко множество
  343.    :type rows: list(list)
  344.    :param scoref: функција за одбирање на најдобар атрибут во даден чекор
  345.    :type scoref: function
  346.    :return: коренот на изграденото дрво на одлука
  347.    :rtype: DecisionNode object
  348.    """
  349.     if len(rows) == 0:
  350.         return DecisionNode()
  351.     current_score = scoref(rows)
  352.  
  353.     # променливи со кои следиме кој критериум е најдобар
  354.     best_gain = 0.0
  355.     best_criteria = None
  356.     best_sets = None
  357.  
  358.     column_count = len(rows[0]) - 1
  359.     for col in range(0, column_count):
  360.         # за секоја колона (col се движи во интервалот од 0 до
  361.         # column_count - 1)
  362.         # Следниов циклус е за генерирање на речник од различни
  363.         # вредности во оваа колона
  364.         column_values = {}
  365.         for row in rows:
  366.             column_values[row[col]] = 1
  367.         # за секоја редица се зема вредноста во оваа колона и се
  368.         # поставува како клуч во column_values
  369.         for value in column_values.keys():
  370.             (set1, set2) = divide_set(rows, col, value)
  371.  
  372.             # Информациона добивка
  373.             p = float(len(set1)) / len(rows)
  374.             gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  375.             if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  376.                 best_gain = gain
  377.                 best_criteria = (col, value)
  378.                 best_sets = (set1, set2)
  379.  
  380.     # Креирај ги подгранките
  381.     if best_gain > 0:
  382.         true_branch = build_tree(best_sets[0], scoref)
  383.         false_branch = build_tree(best_sets[1], scoref)
  384.         return DecisionNode(col=best_criteria[0], value=best_criteria[1],
  385.                             tb=true_branch, fb=false_branch)
  386.     else:
  387.         return DecisionNode(results=unique_counts(rows))
  388.  
  389.  
  390. def print_tree(tree, indent=''):
  391.     """Принтање на дрво на одлука
  392.  
  393.    :param tree: коренот на дрвото на одлучување
  394.    :type tree: DecisionNode object
  395.    :param indent:
  396.    :return: None
  397.    """
  398.     # Дали е ова лист јазел?
  399.     if tree.results:
  400.         print(str(tree.results))
  401.     else:
  402.         # Се печати условот
  403.         print(str(tree.col) + ':' + str(tree.value) + '? ')
  404.         # Се печатат True гранките, па False гранките
  405.         print(indent + 'T->', end='')
  406.         print_tree(tree.tb, indent + '  ')
  407.         print(indent + 'F->', end='')
  408.         print_tree(tree.fb, indent + '  ')
  409.  
  410.  
  411. def classify(observation, tree):
  412.     """Класификација на нов податочен примерок со изградено дрво на одлука
  413.  
  414.    :param observation: еден ред од податочното множество за предвидување
  415.    :type observation: list
  416.    :param tree: коренот на дрвото на одлучување
  417.    :type tree: DecisionNode object
  418.    :return: речник со класите како клуч и бројот на појавување во листот на дрвото
  419.    за класификација како вредност во речникот
  420.    :rtype: dict
  421.    """
  422.     if tree.results:
  423.         return tree.results
  424.     else:
  425.         value = observation[tree.col]
  426.         if isinstance(value, int) or isinstance(value, float):
  427.             compare = compare_numerical
  428.         else:
  429.             compare = compare_nominal
  430.  
  431.         if compare(observation, tree.col, tree.value):
  432.             branch = tree.tb
  433.         else:
  434.             branch = tree.fb
  435.  
  436.         return classify(observation, branch)
  437.  
  438.  
  439. if __name__ == "__main__":
  440.    test_case = input()
  441.    test_case = [float(x) for x in test_case.split(', ')[:-1]] + [test_case.split(', ')[-1]]
  442.    roach = []
  443.    pike = []
  444.    for d in data:
  445.        if d[6] == 'Roach':
  446.            roach.append(d)
  447.        elif d[6] == 'Pike':
  448.            pike.append(d)
  449.    roach = roach[:40]
  450.    pike = pike[:40]
  451.    train_data= roach+pike
  452.    t = build_tree(train_data,entropy)
  453.    klasa = classify(test_case,t)
  454.    print(sorted(list(klasa.items()),reverse=True,key=lambda x:x[1])[0][0])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement