Advertisement
glavinova

[СНЗ] 1. Дрво на одлучување - лаб1

Jul 5th, 2020
1,628
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 12.93 KB | None | 0 0
  1. """Дрво на одлучување Problem 1 (2 / 7)
  2. Да се промени класата за дрво на одлука за да чува и информација на кое ниво во дрвото се наоѓа јазолот. Потоа да се променат и функциите за градење и печатење на дрвото така што за секој јазол се додава информација за нивото и се печати и нивото. Коренот е на нулто ниво. Со функцијата print_tree треба да се испечати креираното дрво на одлука. Прочитана инстанца од стандарден влез да се додаде на тренинг множеството и потоа да се истренира и испечати дрвото на одлука со ова податочно множество."""
  3.  
  4. from math import log
  5.  
  6.  
  7. def unique_counts(rows):
  8.     """Креирај броење на можни резултати (последната колона
  9.    во секоја редица е класата)
  10.  
  11.    :param rows: dataset
  12.    :type rows: list
  13.    :return: dictionary of possible classes as keys and count
  14.             as values
  15.    :rtype: dict
  16.    """
  17.     results = {}
  18.     for row in rows:
  19.         # Клацата е последната колона
  20.         r = row[len(row) - 1]
  21.         if r not in results:
  22.             results[r] = 0
  23.         results[r] += 1
  24.     return results
  25.  
  26.  
  27. def gini_impurity(rows):
  28.     """Probability that a randomly placed item will
  29.    be in the wrong category
  30.  
  31.    :param rows: dataset
  32.    :type rows: list
  33.    :return: Gini impurity
  34.    :rtype: float
  35.    """
  36.     total = len(rows)
  37.     counts = unique_counts(rows)
  38.     imp = 0
  39.     for k1 in counts:
  40.         p1 = float(counts[k1]) / total
  41.         for k2 in counts:
  42.             if k1 == k2:
  43.                 continue
  44.             p2 = float(counts[k2]) / total
  45.             imp += p1 * p2
  46.     return imp
  47.  
  48.  
  49. def entropy(rows):
  50.     """Ентропијата е сума од p(x)log(p(x)) за сите
  51.    можни резултати
  52.  
  53.    :param rows: податочно множество
  54.    :type rows: list
  55.    :return: вредност за ентропијата
  56.    :rtype: float
  57.    """
  58.     log2 = lambda x: log(x) / log(2)
  59.     results = unique_counts(rows)
  60.     # Пресметка на ентропијата
  61.     ent = 0.0
  62.     for r in results.keys():
  63.         p = float(results[r]) / len(rows)
  64.         ent = ent - p * log2(p)
  65.     return ent
  66.  
  67.  
  68. class DecisionNode:
  69.     def __init__(self, col=-1, value=None, results=None, tb=None, fb=None, level = None):
  70.         """
  71.        :param col: индексот на колоната (атрибутот) од тренинг множеството
  72.                    која се претставува со оваа инстанца т.е. со овој јазол
  73.        :type col: int
  74.        :param value: вредноста на јазолот според кој се дели дрвото
  75.        :param results: резултати за тековната гранка, вредност (различна
  76.                        од None) само кај јазлите-листови во кои се донесува
  77.                        одлуката.
  78.        :type results: dict
  79.        :param tb: гранка која се дели од тековниот јазол кога вредноста е
  80.                   еднаква на value
  81.        :type tb: DecisionNode
  82.        :param fb: гранка која се дели од тековниот јазол кога вредноста е
  83.                   различна од value
  84.        :type fb: DecisionNode
  85.        """
  86.         self.col = col
  87.         self.value = value
  88.         self.results = results
  89.         self.tb = tb
  90.         self.fb = fb
  91.         self.level = level
  92.  
  93.  
  94. def compare_numerical(row, column, value):
  95.     """Споредба на вредноста од редицата на посакуваната колона со
  96.    зададена нумеричка вредност
  97.  
  98.    :param row: дадена редица во податочното множество
  99.    :type row: list
  100.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  101.    :type column: int
  102.    :param value: вредност на јазелот во согласност со кој се прави
  103.                  поделбата во дрвото
  104.    :type value: int or float
  105.    :return: True ако редицата >= value, инаку False
  106.    :rtype: bool
  107.    """
  108.     return row[column] >= value
  109.  
  110.  
  111. def compare_nominal(row, column, value):
  112.     """Споредба на вредноста од редицата на посакуваната колона со
  113.    зададена номинална вредност
  114.  
  115.    :param row: дадена редица во податочното множество
  116.    :type row: list
  117.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  118.    :type column: int
  119.    :param value: вредност на јазелот во согласност со кој се прави
  120.                  поделбата во дрвото
  121.    :type value: str
  122.    :return: True ако редицата == value, инаку False
  123.    :rtype: bool
  124.    """
  125.     return row[column] == value
  126.  
  127.  
  128. def divide_set(rows, column, value):
  129.     """Поделба на множеството според одредена колона. Може да се справи
  130.    со нумерички или номинални вредности.
  131.  
  132.    :param rows: тренирачко множество
  133.    :type rows: list(list)
  134.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  135.    :type column: int
  136.    :param value: вредност на јазелот во зависност со кој се прави поделбата
  137.                  во дрвото за конкретната гранка
  138.    :type value: int or float or str
  139.    :return: поделени подмножества
  140.    :rtype: list, list
  141.    """
  142.     # Направи функција која ни кажува дали редицата е во
  143.     # првата група (True) или втората група (False)
  144.     if isinstance(value, int) or isinstance(value, float):
  145.         # ако вредноста за споредба е од тип int или float
  146.         split_function = compare_numerical
  147.     else:
  148.         # ако вредноста за споредба е од друг тип (string)
  149.         split_function = compare_nominal
  150.  
  151.     # Подели ги редиците во две подмножества и врати ги
  152.     # за секој ред за кој split_function враќа True
  153.     set1 = [row for row in rows if
  154.             split_function(row, column, value)]
  155.     # set1 = []
  156.     # for row in rows:
  157.     #     if not split_function(row, column, value):
  158.     #         set1.append(row)
  159.     # за секој ред за кој split_function враќа False
  160.     set2 = [row for row in rows if
  161.             not split_function(row, column, value)]
  162.     return set1, set2
  163.  
  164.  
  165. def build_tree(rows, scoref=entropy, level = 0):
  166.     """Градење на дрво на одлука.
  167.  
  168.    :param rows: тренирачко множество
  169.    :type rows: list(list)
  170.    :param scoref: функција за одбирање на најдобар атрибут во даден чекор
  171.    :type scoref: function
  172.    :return: коренот на изграденото дрво на одлука
  173.    :rtype: DecisionNode object
  174.    """
  175.     if len(rows) == 0:
  176.         return DecisionNode()
  177.     current_score = scoref(rows)
  178.  
  179.     # променливи со кои следиме кој критериум е најдобар
  180.     best_gain = 0.0
  181.     best_criteria = None
  182.     best_sets = None
  183.  
  184.     column_count = len(rows[0]) - 1
  185.     for col in range(0, column_count):
  186.         # за секоја колона (col се движи во интервалот од 0 до
  187.         # column_count - 1)
  188.         # Следниов циклус е за генерирање на речник од различни
  189.         # вредности во оваа колона
  190.         column_values = {}
  191.         for row in rows:
  192.             column_values[row[col]] = 1
  193.         # за секоја редица се зема вредноста во оваа колона и се
  194.         # поставува како клуч во column_values
  195.         for value in column_values.keys():
  196.             (set1, set2) = divide_set(rows, col, value)
  197.  
  198.             # Информациона добивка
  199.             p = float(len(set1)) / len(rows)
  200.             gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  201.             if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  202.                 best_gain = gain
  203.                 best_criteria = (col, value)
  204.                 best_sets = (set1, set2)
  205.  
  206.     # Креирај ги подгранките
  207.     if best_gain > 0:
  208.         true_branch = build_tree(best_sets[0], scoref, level+1)
  209.         false_branch = build_tree(best_sets[1], scoref, level+1)
  210.         return DecisionNode(col=best_criteria[0], value=best_criteria[1],
  211.                             tb=true_branch, fb=false_branch, level=level)
  212.     else:
  213.         return DecisionNode(results=unique_counts(rows),level=0)
  214.  
  215.  
  216. def print_tree(tree, indent=''):
  217.     """Принтање на дрво на одлука
  218.  
  219.    :param tree: коренот на дрвото на одлучување
  220.    :type tree: DecisionNode object
  221.    :param indent:
  222.    :return: None
  223.    """
  224.     # Дали е ова лист јазел?
  225.     if tree.results:
  226.         print(str(tree.results))
  227.     else:
  228.         # Се печати условот
  229.         print(str(tree.col) + ':' + str(tree.value) + '? ', 'Level=' +str(tree.level))
  230.         # Се печатат True гранките, па False гранките
  231.         print(indent + 'T->', end='')
  232.         print_tree(tree.tb, indent + '  ')
  233.         print(indent + 'F->', end='')
  234.         print_tree(tree.fb, indent + '  ')
  235.  
  236.  
  237. def classify(observation, tree):
  238.     """Класификација на нов податочен примерок со изградено дрво на одлука
  239.  
  240.    :param observation: еден ред од податочното множество за предвидување
  241.    :type observation: list
  242.    :param tree: коренот на дрвото на одлучување
  243.    :type tree: DecisionNode object
  244.    :return: речник со класите како клуч и бројот на појавување во листот на дрвото
  245.    за класификација како вредност во речникот
  246.    :rtype: dict
  247.    """
  248.     if tree.results:
  249.         return tree.results
  250.     else:
  251.         value = observation[tree.col]
  252.         if isinstance(value, int) or isinstance(value, float):
  253.             compare = compare_numerical
  254.         else:
  255.             compare = compare_nominal
  256.  
  257.         if compare(observation, tree.col, tree.value):
  258.             branch = tree.tb
  259.         else:
  260.             branch = tree.fb
  261.  
  262.         return classify(observation, branch)
  263.  
  264. training_data = [['slashdot', 'USA', 'yes', 18, 'None'],
  265.                  ['google', 'France', 'yes', 23, 'Premium'],
  266.                  ['google', 'France', 'yes', 23, 'Basic'],
  267.                  ['google', 'France', 'yes', 23, 'Basic'],
  268.                  ['digg', 'USA', 'yes', 24, 'Basic'],
  269.                  ['kiwitobes', 'France', 'yes', 23, 'Basic'],
  270.                  ['google', 'UK', 'no', 21, 'Premium'],
  271.                  ['(direct)', 'New Zealand', 'no', 12, 'None'],
  272.                  ['(direct)', 'UK', 'no', 21, 'Basic'],
  273.                  ['google', 'USA', 'no', 24, 'Premium'],
  274.                  ['slashdot', 'France', 'yes', 19, 'None'],
  275.                  ['digg', 'USA', 'no', 18, 'None'],
  276.                  ['google', 'UK', 'no', 18, 'None'],
  277.                  ['kiwitobes', 'UK', 'no', 19, 'None'],
  278.                  ['digg', 'New Zealand', 'yes', 12, 'Basic'],
  279.                  ['slashdot', 'UK', 'no', 21, 'None'],
  280.                  ['google', 'UK', 'yes', 18, 'Basic'],
  281.                  ['kiwitobes', 'France', 'yes', 19, 'Basic']]
  282.  
  283.  
  284. if __name__ == "__main__":
  285.     referrer = input()
  286.     location = input()
  287.     readFAQ = input()
  288.     pagesVisited = int(input())
  289.     serviceChosen = input()
  290.  
  291.     testCase = [referrer, location, readFAQ, pagesVisited, serviceChosen]
  292.     training_data.append(testCase)
  293.     t = build_tree(training_data,entropy)
  294.     klasa = classify(testCase,t)
  295.     print_tree(t)
  296.     #print(sorted(list(klasa.items()),reverse=True,key=lambda x: x[1])[0][0])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement