lolipop12

[SNZ]Класификација со мнозинство на гласови

Nov 17th, 2019
71
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 14.89 KB | None | 0 0
  1. from math import log
  2.  
  3.  
  4. def unique_counts(rows):
  5.     """Креирај броење на можни резултати (последната колона
  6.    во секоја редица е класата)
  7.  
  8.    :param rows: dataset
  9.    :type rows: list
  10.    :return: dictionary of possible classes as keys and count
  11.             as values
  12.    :rtype: dict
  13.    """
  14.     results = {}
  15.     for row in rows:
  16.         # Клацата е последната колона
  17.         r = row[len(row) - 1]
  18.         if r not in results:
  19.             results[r] = 0
  20.         results[r] += 1
  21.     return results
  22.  
  23.  
  24. def gini_impurity(rows):
  25.     """Probability that a randomly placed item will
  26.    be in the wrong category
  27.  
  28.    :param rows: dataset
  29.    :type rows: list
  30.    :return: Gini impurity
  31.    :rtype: float
  32.    """
  33.     total = len(rows)
  34.     counts = unique_counts(rows)
  35.     imp = 0
  36.     for k1 in counts:
  37.         p1 = float(counts[k1]) / total
  38.         for k2 in counts:
  39.             if k1 == k2:
  40.                 continue
  41.             p2 = float(counts[k2]) / total
  42.             imp += p1 * p2
  43.     return imp
  44.  
  45.  
  46. def entropy(rows):
  47.     """Ентропијата е сума од p(x)log(p(x)) за сите
  48.    можни резултати
  49.  
  50.    :param rows: податочно множество
  51.    :type rows: list
  52.    :return: вредност за ентропијата
  53.    :rtype: float
  54.    """
  55.     log2 = lambda x: log(x) / log(2)
  56.     results = unique_counts(rows)
  57.     # Пресметка на ентропијата
  58.     ent = 0.0
  59.     for r in results.keys():
  60.         p = float(results[r]) / len(rows)
  61.         ent = ent - p * log2(p)
  62.     return ent
  63.  
  64.  
  65. class DecisionNode:
  66.     def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
  67.         """
  68.        :param col: индексот на колоната (атрибутот) од тренинг множеството
  69.                    која се претставува со оваа инстанца т.е. со овој јазол
  70.        :type col: int
  71.        :param value: вредноста на јазолот според кој се дели дрвото
  72.        :param results: резултати за тековната гранка, вредност (различна
  73.                        од None) само кај јазлите-листови во кои се донесува
  74.                        одлуката.
  75.        :type results: dict
  76.        :param tb: гранка која се дели од тековниот јазол кога вредноста е
  77.                   еднаква на value
  78.        :type tb: DecisionNode
  79.        :param fb: гранка која се дели од тековниот јазол кога вредноста е
  80.                   различна од value
  81.        :type fb: DecisionNode
  82.        """
  83.         self.col = col
  84.         self.value = value
  85.         self.results = results
  86.         self.tb = tb
  87.         self.fb = fb
  88.  
  89.  
  90. def compare_numerical(row, column, value):
  91.     """Споредба на вредноста од редицата на посакуваната колона со
  92.    зададена нумеричка вредност
  93.  
  94.    :param row: дадена редица во податочното множество
  95.    :type row: list
  96.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  97.    :type column: int
  98.    :param value: вредност на јазелот во согласност со кој се прави
  99.                  поделбата во дрвото
  100.    :type value: int or float
  101.    :return: True ако редицата >= value, инаку False
  102.    :rtype: bool
  103.    """
  104.     return row[column] >= value
  105.  
  106.  
  107. def compare_nominal(row, column, value):
  108.     """Споредба на вредноста од редицата на посакуваната колона со
  109.    зададена номинална вредност
  110.  
  111.    :param row: дадена редица во податочното множество
  112.    :type row: list
  113.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  114.    :type column: int
  115.    :param value: вредност на јазелот во согласност со кој се прави
  116.                  поделбата во дрвото
  117.    :type value: str
  118.    :return: True ако редицата == value, инаку False
  119.    :rtype: bool
  120.    """
  121.     return row[column] == value
  122.  
  123.  
  124. def divide_set(rows, column, value):
  125.     """Поделба на множеството според одредена колона. Може да се справи
  126.    со нумерички или номинални вредности.
  127.  
  128.    :param rows: тренирачко множество
  129.    :type rows: list(list)
  130.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  131.    :type column: int
  132.    :param value: вредност на јазелот во зависност со кој се прави поделбата
  133.                  во дрвото за конкретната гранка
  134.    :type value: int or float or str
  135.    :return: поделени подмножества
  136.    :rtype: list, list
  137.    """
  138.     # Направи функција која ни кажува дали редицата е во
  139.     # првата група (True) или втората група (False)
  140.     if isinstance(value, int) or isinstance(value, float):
  141.         # ако вредноста за споредба е од тип int или float
  142.         split_function = compare_numerical
  143.     else:
  144.         # ако вредноста за споредба е од друг тип (string)
  145.         split_function = compare_nominal
  146.  
  147.     # Подели ги редиците во две подмножества и врати ги
  148.     # за секој ред за кој split_function враќа True
  149.     set1 = [row for row in rows if
  150.             split_function(row, column, value)]
  151.     # set1 = []
  152.     # for row in rows:
  153.     #     if not split_function(row, column, value):
  154.     #         set1.append(row)
  155.     # за секој ред за кој split_function враќа False
  156.     set2 = [row for row in rows if
  157.             not split_function(row, column, value)]
  158.     return set1, set2
  159.  
  160.  
  161. def build_tree(rows, scoref=entropy):
  162.     if len(rows) == 0:
  163.         return DecisionNode()
  164.     current_score = scoref(rows)
  165.  
  166.     # променливи со кои следиме кој критериум е најдобар
  167.     best_gain = 0.0
  168.     best_criteria = None
  169.     best_sets = None
  170.  
  171.     column_count = len(rows[0]) - 1
  172.     for col in range(0, column_count):
  173.         # за секоја колона (col се движи во интервалот од 0 до
  174.         # column_count - 1)
  175.         # Следниов циклус е за генерирање на речник од различни
  176.         # вредности во оваа колона
  177.         column_values = {}
  178.         for row in rows:
  179.             column_values[row[col]] = 1
  180.         # за секоја редица се зема вредноста во оваа колона и се
  181.         # поставува како клуч во column_values
  182.         for value in column_values.keys():
  183.             (set1, set2) = divide_set(rows, col, value)
  184.  
  185.             # Информациона добивка
  186.             p = float(len(set1)) / len(rows)
  187.             gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  188.             if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  189.                 best_gain = gain
  190.                 best_criteria = (col, value)
  191.                 best_sets = (set1, set2)
  192.  
  193.     # Креирај ги подгранките
  194.     if best_gain > 0:
  195.         true_branch = build_tree(best_sets[0], scoref)
  196.         false_branch = build_tree(best_sets[1], scoref)
  197.         return DecisionNode(col=best_criteria[0], value=best_criteria[1],
  198.                             tb=true_branch, fb=false_branch)
  199.     else:
  200.         return DecisionNode(results=unique_counts(rows))
  201.  
  202.  
  203. def print_tree(tree, indent=''):
  204.     # Дали е ова лист јазел?
  205.     if tree.results:
  206.         print(str(tree.results))
  207.     else:
  208.         # Се печати условот
  209.         print(str(tree.col) + ':' + str(tree.value) + '? ')
  210.         # Се печатат True гранките, па False гранките
  211.         print(indent + 'T->', end='')
  212.         print_tree(tree.tb, indent + '  ')
  213.         print(indent + 'F->', end='')
  214.         print_tree(tree.fb, indent + '  ')
  215.  
  216.  
  217. def classify(observation, tree):
  218.     if tree.results:
  219.         return tree.results
  220.     else:
  221.         value = observation[tree.col]
  222.         if isinstance(value, int) or isinstance(value, float):
  223.             compare = compare_numerical
  224.         else:
  225.             compare = compare_nominal
  226.  
  227.         if compare(observation, tree.col, tree.value):
  228.             branch = tree.tb
  229.         else:
  230.             branch = tree.fb
  231.  
  232.         return classify(observation, branch)
  233. dataset = [[6.3, 2.3, 4.4, 1.3, 2],
  234.            [6.4, 2.8, 5.6, 2.1, 0],
  235.            [5.1, 3.3, 1.7, 0.5, 1],
  236.            [5.1, 3.5, 1.4, 0.2, 1],
  237.            [4.6, 3.1, 1.5, 0.2, 1],
  238.            [5.8, 2.7, 5.1, 1.9, 0],
  239.            [5.5, 3.5, 1.3, 0.2, 1],
  240.            [5.7, 2.6, 3.5, 1.0, 2],
  241.            [5.0, 3.5, 1.3, 0.3, 1],
  242.            [6.3, 2.5, 5.0, 1.9, 0],
  243.            [6.2, 2.2, 4.5, 1.5, 2],
  244.            [5.0, 3.4, 1.6, 0.4, 1],
  245.            [5.7, 4.4, 1.5, 0.4, 1],
  246.            [4.9, 2.4, 3.3, 1.0, 2],
  247.            [4.4, 2.9, 1.4, 0.2, 1],
  248.            [5.5, 2.4, 3.7, 1.0, 2],
  249.            [5.6, 2.5, 3.9, 1.1, 2],
  250.            [5.6, 2.8, 4.9, 2.0, 0],
  251.            [4.8, 3.4, 1.6, 0.2, 1],
  252.            [5.6, 3.0, 4.5, 1.5, 2],
  253.            [6.0, 3.0, 4.8, 1.8, 0],
  254.            [6.3, 3.3, 4.7, 1.6, 2],
  255.            [4.8, 3.0, 1.4, 0.1, 1],
  256.            [7.9, 3.8, 6.4, 2.0, 0],
  257.            [4.9, 3.0, 1.4, 0.2, 1],
  258.            [4.3, 3.0, 1.1, 0.1, 1],
  259.            [6.8, 3.2, 5.9, 2.3, 0],
  260.            [5.6, 2.7, 4.2, 1.3, 2],
  261.            [5.2, 4.1, 1.5, 0.1, 1],
  262.            [6.2, 2.9, 4.3, 1.3, 2],
  263.            [6.5, 2.8, 4.6, 1.5, 2],
  264.            [5.4, 3.9, 1.3, 0.4, 1],
  265.            [5.8, 2.6, 4.0, 1.2, 2],
  266.            [5.4, 3.7, 1.5, 0.2, 1],
  267.            [4.5, 2.3, 1.3, 0.3, 1],
  268.            [6.3, 3.4, 5.6, 2.4, 0],
  269.            [6.2, 3.4, 5.4, 2.3, 0],
  270.            [5.7, 2.5, 5.0, 2.0, 0],
  271.            [5.8, 2.7, 3.9, 1.2, 2],
  272.            [6.4, 2.7, 5.3, 1.9, 0],
  273.            [5.1, 3.8, 1.6, 0.2, 1],
  274.            [6.3, 2.5, 4.9, 1.5, 2],
  275.            [7.7, 2.8, 6.7, 2.0, 0],
  276.            [5.1, 3.5, 1.4, 0.3, 1],
  277.            [6.8, 2.8, 4.8, 1.4, 2],
  278.            [6.1, 3.0, 4.6, 1.4, 2],
  279.            [5.5, 4.2, 1.4, 0.2, 1],
  280.            [5.0, 2.0, 3.5, 1.0, 2],
  281.            [7.7, 3.0, 6.1, 2.3, 0],
  282.            [5.1, 2.5, 3.0, 1.1, 2],
  283.            [5.9, 3.0, 5.1, 1.8, 0],
  284.            [7.2, 3.2, 6.0, 1.8, 0],
  285.            [4.9, 3.1, 1.5, 0.2, 1],
  286.            [5.7, 3.0, 4.2, 1.2, 2],
  287.            [6.1, 2.9, 4.7, 1.4, 2],
  288.            [5.0, 3.2, 1.2, 0.2, 1],
  289.            [4.4, 3.2, 1.3, 0.2, 1],
  290.            [6.7, 3.1, 5.6, 2.4, 0],
  291.            [4.6, 3.6, 1.0, 0.2, 1],
  292.            [5.1, 3.4, 1.5, 0.2, 1],
  293.            [5.2, 2.7, 3.9, 1.4, 2],
  294.            [6.4, 3.1, 5.5, 1.8, 0],
  295.            [7.4, 2.8, 6.1, 1.9, 0],
  296.            [4.9, 3.1, 1.5, 0.1, 1],
  297.            [5.0, 3.5, 1.6, 0.6, 1],
  298.            [6.7, 3.1, 4.7, 1.5, 2],
  299.            [6.4, 3.2, 5.3, 2.3, 0],
  300.            [6.3, 2.7, 4.9, 1.8, 0],
  301.            [5.8, 4.0, 1.2, 0.2, 1],
  302.            [6.9, 3.1, 5.4, 2.1, 0],
  303.            [5.9, 3.2, 4.8, 1.8, 2],
  304.            [6.6, 2.9, 4.6, 1.3, 2],
  305.            [6.1, 2.8, 4.0, 1.3, 2],
  306.            [7.7, 2.6, 6.9, 2.3, 0],
  307.            [5.5, 2.6, 4.4, 1.2, 2],
  308.            [6.3, 2.9, 5.6, 1.8, 0],
  309.            [7.2, 3.0, 5.8, 1.6, 0],
  310.            [6.5, 3.0, 5.8, 2.2, 0],
  311.            [5.4, 3.9, 1.7, 0.4, 1],
  312.            [6.5, 3.2, 5.1, 2.0, 0],
  313.            [5.9, 3.0, 4.2, 1.5, 2],
  314.            [5.1, 3.7, 1.5, 0.4, 1],
  315.            [5.7, 2.8, 4.5, 1.3, 2],
  316.            [5.4, 3.4, 1.5, 0.4, 1],
  317.            [4.6, 3.4, 1.4, 0.3, 1],
  318.            [4.9, 3.6, 1.4, 0.1, 1],
  319.            [6.7, 2.5, 5.8, 1.8, 0],
  320.            [5.0, 3.6, 1.4, 0.2, 1],
  321.            [6.7, 3.3, 5.7, 2.5, 0],
  322.            [4.4, 3.0, 1.3, 0.2, 1],
  323.            [6.0, 2.2, 5.0, 1.5, 0],
  324.            [6.0, 2.2, 4.0, 1.0, 2],
  325.            [5.0, 3.4, 1.5, 0.2, 1],
  326.            [5.7, 2.8, 4.1, 1.3, 2],
  327.            [5.5, 2.4, 3.8, 1.1, 2],
  328.            [5.1, 3.8, 1.9, 0.4, 1],
  329.            [6.9, 3.1, 5.1, 2.3, 0],
  330.            [5.6, 2.9, 3.6, 1.3, 2],
  331.            [6.1, 2.8, 4.7, 1.2, 2],
  332.            [5.5, 2.5, 4.0, 1.3, 2],
  333.            [5.5, 2.3, 4.0, 1.3, 2],
  334.            [6.0, 2.9, 4.5, 1.5, 2],
  335.            [5.1, 3.8, 1.5, 0.3, 1],
  336.            [5.7, 3.8, 1.7, 0.3, 1],
  337.            [6.7, 3.3, 5.7, 2.1, 0],
  338.            [4.8, 3.1, 1.6, 0.2, 1],
  339.            [5.4, 3.0, 4.5, 1.5, 2],
  340.            [6.5, 3.0, 5.2, 2.0, 0],
  341.            [6.8, 3.0, 5.5, 2.1, 0],
  342.            [7.6, 3.0, 6.6, 2.1, 0],
  343.            [5.0, 3.0, 1.6, 0.2, 1],
  344.            [6.7, 3.0, 5.0, 1.7, 2],
  345.            [4.8, 3.4, 1.9, 0.2, 1],
  346.            [5.8, 2.8, 5.1, 2.4, 0],
  347.            [5.0, 2.3, 3.3, 1.0, 2],
  348.            [4.8, 3.0, 1.4, 0.3, 1],
  349.            [5.2, 3.5, 1.5, 0.2, 1],
  350.            [6.1, 2.6, 5.6, 1.4, 0],
  351.            [5.8, 2.7, 4.1, 1.0, 2],
  352.            [6.9, 3.2, 5.7, 2.3, 0],
  353.            [6.4, 2.9, 4.3, 1.3, 2],
  354.            [7.3, 2.9, 6.3, 1.8, 0],
  355.            [6.3, 2.8, 5.1, 1.5, 0],
  356.            [6.2, 2.8, 4.8, 1.8, 0],
  357.            [6.7, 3.1, 4.4, 1.4, 2],
  358.            [6.0, 2.7, 5.1, 1.6, 2],
  359.            [6.5, 3.0, 5.5, 1.8, 0],
  360.            [6.1, 3.0, 4.9, 1.8, 0],
  361.            [5.6, 3.0, 4.1, 1.3, 2],
  362.            [4.7, 3.2, 1.6, 0.2, 1],
  363.            [6.6, 3.0, 4.4, 1.4, 2]]
  364.  
  365. if __name__ == '__main__':
  366.     x = input() .split(', ')
  367.     test_case = list(map(float, x[:-1])) + [int(x[-1])]
  368.  
  369.     mnoz1 = dataset[:int(len(dataset) * 0.3)]
  370.     mnoz2 = dataset[int(len(dataset) * 0.3):int(len(dataset) * 0.6)]
  371.     mnoz3 = dataset[int(len(dataset) * 0.6):]
  372.  
  373.     drvo1 = build_tree(mnoz1, entropy)
  374.     drvo2 = build_tree(mnoz2, entropy)
  375.     drvo3 = build_tree(mnoz3, entropy)
  376.  
  377.     glasovi = {0: 0, 1: 0, 2: 0}
  378.  
  379.     klasa1 = max(classify(test_case, drvo1).items(), key=lambda x: x[1])[0]
  380.     klasa2 = max(classify(test_case, drvo2).items(), key=lambda x: x[1])[0]
  381.     klasa3 = max(classify(test_case, drvo3).items(), key=lambda x: x[1])[0]
  382.  
  383.     glasovi[klasa1] += 1
  384.     glasovi[klasa2] += 1
  385.     glasovi[klasa3] += 1
  386.  
  387.     sorted_klasi = list(sorted(glasovi.items(), key=lambda x: x[1], reverse=True))
  388.  
  389.    
  390.     klasa = sorted_klasi[0][0]
  391.     print(sorted_klasi)
  392.    
  393.    
  394.     if sorted_klasi[0][1] == sorted_klasi[1][1]:
  395.         print('unknown')
  396.     else:
  397.         print("Predvidena klasa:",klasa)
Add Comment
Please, Sign In to add comment