SHARE
TWEET

Untitled

a guest May 26th, 2019 80 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from math import log
  2.  
  3.  
  4. def unique_counts(rows):
  5.     """Креирај броење на можни резултати (последната колона
  6.    во секоја редица е класата)
  7.  
  8.    :param rows: dataset
  9.    :type rows: list
  10.    :return: dictionary of possible classes as keys and count
  11.             as values
  12.    :rtype: dict
  13.    """
  14.     results = {}
  15.     for row in rows:
  16.         # Клацата е последната колона
  17.         r = row[len(row) - 1]
  18.         if r not in results:
  19.             results[r] = 0
  20.         results[r] += 1
  21.     return results
  22.  
  23.  
  24. def gini_impurity(rows):
  25.     """Probability that a randomly placed item will
  26.    be in the wrong category
  27.  
  28.    :param rows: dataset
  29.    :type rows: list
  30.    :return: Gini impurity
  31.    :rtype: float
  32.    """
  33.     total = len(rows)
  34.     counts = unique_counts(rows)
  35.     imp = 0
  36.     for k1 in counts:
  37.         p1 = float(counts[k1]) / total
  38.         for k2 in counts:
  39.             if k1 == k2:
  40.                 continue
  41.             p2 = float(counts[k2]) / total
  42.             imp += p1 * p2
  43.     return imp
  44.  
  45.  
  46. def entropy(rows):
  47.     """Ентропијата е сума од p(x)log(p(x)) за сите
  48.    можни резултати
  49.  
  50.    :param rows: податочно множество
  51.    :type rows: list
  52.    :return: вредност за ентропијата
  53.    :rtype: float
  54.    """
  55.     log2 = lambda x: log(x) / log(2)
  56.     results = unique_counts(rows)
  57.     # Пресметка на ентропијата
  58.     ent = 0.0
  59.     for r in results.keys():
  60.         p = float(results[r]) / len(rows)
  61.         ent = ent - p * log2(p)
  62.     return ent
  63.  
  64.  
  65. class DecisionNode:
  66.     def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
  67.         """
  68.        :param col: индексот на колоната (атрибутот) од тренинг множеството
  69.                    која се претставува со оваа инстанца т.е. со овој јазол
  70.        :type col: int
  71.        :param value: вредноста на јазолот според кој се дели дрвото
  72.        :param results: резултати за тековната гранка, вредност (различна
  73.                        од None) само кај јазлите-листови во кои се донесува
  74.                        одлуката.
  75.        :type results: dict
  76.        :param tb: гранка која се дели од тековниот јазол кога вредноста е
  77.                   еднаква на value
  78.        :type tb: DecisionNode
  79.        :param fb: гранка која се дели од тековниот јазол кога вредноста е
  80.                   различна од value
  81.        :type fb: DecisionNode
  82.        """
  83.         self.col = col
  84.         self.value = value
  85.         self.results = results
  86.         self.tb = tb
  87.         self.fb = fb
  88.  
  89.  
  90. def compare_numerical(row, column, value):
  91.     """Споредба на вредноста од редицата на посакуваната колона со
  92.    зададена нумеричка вредност
  93.  
  94.    :param row: дадена редица во податочното множество
  95.    :type row: list
  96.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  97.    :type column: int
  98.    :param value: вредност на јазелот во согласност со кој се прави
  99.                  поделбата во дрвото
  100.    :type value: int or float
  101.    :return: True ако редицата >= value, инаку False
  102.    :rtype: bool
  103.    """
  104.     return row[column] >= value
  105.  
  106.  
  107. def compare_nominal(row, column, value):
  108.     """Споредба на вредноста од редицата на посакуваната колона со
  109.    зададена номинална вредност
  110.  
  111.    :param row: дадена редица во податочното множество
  112.    :type row: list
  113.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  114.    :type column: int
  115.    :param value: вредност на јазелот во согласност со кој се прави
  116.                  поделбата во дрвото
  117.    :type value: str
  118.    :return: True ако редицата == value, инаку False
  119.    :rtype: bool
  120.    """
  121.     return row[column] == value
  122.  
  123.  
  124. def divide_set(rows, column, value):
  125.     """Поделба на множеството според одредена колона. Може да се справи
  126.    со нумерички или номинални вредности.
  127.  
  128.    :param rows: тренирачко множество
  129.    :type rows: list(list)
  130.    :param column: индекс на колоната (атрибутот) од тренирачкото множество
  131.    :type column: int
  132.    :param value: вредност на јазелот во зависност со кој се прави поделбата
  133.                  во дрвото за конкретната гранка
  134.    :type value: int or float or str
  135.    :return: поделени подмножества
  136.    :rtype: list, list
  137.    """
  138.     # Направи функција која ни кажува дали редицата е во
  139.     # првата група (True) или втората група (False)
  140.     if isinstance(value, int) or isinstance(value, float):
  141.         # ако вредноста за споредба е од тип int или float
  142.         split_function = compare_numerical
  143.     else:
  144.         # ако вредноста за споредба е од друг тип (string)
  145.         split_function = compare_nominal
  146.  
  147.     # Подели ги редиците во две подмножества и врати ги
  148.     # за секој ред за кој split_function враќа True
  149.     set1 = [row for row in rows if
  150.             split_function(row, column, value)]
  151.     # за секој ред за кој split_function враќа False
  152.     set2 = [row for row in rows if
  153.             not split_function(row, column, value)]
  154.     return set1, set2
  155.  
  156.  
  157. def build_tree(rows, scoref=entropy):
  158.     if len(rows) == 0:
  159.         return DecisionNode()
  160.     current_score = scoref(rows)
  161.  
  162.     # променливи со кои следиме кој критериум е најдобар
  163.     best_gain = 0.0
  164.     best_criteria = None
  165.     best_sets = None
  166.  
  167.     column_count = len(rows[0]) - 1
  168.     for col in range(0, column_count):
  169.         # за секоја колона (col се движи во интервалот од 0 до
  170.         # column_count - 1)
  171.         # Следниов циклус е за генерирање на речник од различни
  172.         # вредности во оваа колона
  173.         column_values = {}
  174.         for row in rows:
  175.             column_values[row[col]] = 1
  176.         # за секоја редица се зема вредноста во оваа колона и се
  177.         # поставува како клуч во column_values
  178.         for value in column_values.keys():
  179.             (set1, set2) = divide_set(rows, col, value)
  180.  
  181.             # Информациона добивка
  182.             p = float(len(set1)) / len(rows)
  183.             gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  184.             if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  185.                 best_gain = gain
  186.                 best_criteria = (col, value)
  187.                 best_sets = (set1, set2)
  188.  
  189.     # Креирај ги подгранките
  190.     if best_gain > 0:
  191.         true_branch = build_tree(best_sets[0], scoref)
  192.         false_branch = build_tree(best_sets[1], scoref)
  193.         return DecisionNode(col=best_criteria[0], value=best_criteria[1],
  194.                             tb=true_branch, fb=false_branch)
  195.     else:
  196.         return DecisionNode(results=unique_counts(rows))
  197.  
  198.  
  199. def print_tree(tree, indent=''):
  200.     # Дали е ова лист јазел?
  201.     if tree.results:
  202.         print(str(tree.results))
  203.     else:
  204.         # Се печати условот
  205.         print(str(tree.col) + ':' + str(tree.value) + '? ')
  206.         # Се печатат True гранките, па False гранките
  207.         print(indent + 'T->', end='')
  208.         print_tree(tree.tb, indent + '  ')
  209.         print(indent + 'F->', end='')
  210.         print_tree(tree.fb, indent + '  ')
  211.  
  212.  
  213. def classify(observation, tree):
  214.     if tree.results:
  215.         return tree.results
  216.     else:
  217.         value = observation[tree.col]
  218.         if isinstance(value, int) or isinstance(value, float):
  219.             compare = compare_numerical
  220.         else:
  221.             compare = compare_nominal
  222.  
  223.         if compare(observation, tree.col, tree.value):
  224.             branch = tree.tb
  225.         else:
  226.             branch = tree.fb
  227.  
  228.         return classify(observation, branch)
  229.  
  230. if __name__ == '__main__':
  231.     my_data = [[0, 0, 0, 0, 'TRUE'],
  232.                [2, 1, 0, 1, 'TRUE'],
  233.                [0, 2, 0, 2, 'FALSE'],
  234.                [2, 0, 2, 1, 'FALSE'],
  235.                [2, 2, 2, 1, 'TRUE'],
  236.                [2, 0, 2, 2, 'TRUE'],
  237.                [0, 1, 0, 1, 'TRUE'],
  238.                [2, 1, 0, 0, 'TRUE'],
  239.                [2, 1, 2, 1, 'TRUE'],
  240.                [0, 2, 0, 1, 'FALSE'],
  241.                [0, 1, 2, 1, 'FALSE'],
  242.                [0, 1, 2, 2, 'TRUE'],
  243.                [0, 0, 2, 0, 'FALSE'],
  244.                [0, 1, 2, 1, 'TRUE'],
  245.                [0, 2, 2, 1, 'TRUE'],
  246.                [2, 0, 2, 2, 'FALSE'],
  247.                [0, 1, 2, 1, 'FALSE'],
  248.                [2, 2, 0, 2, 'TRUE'],
  249.                [2, 1, 2, 0, 'TRUE'],
  250.                [2, 1, 0, 1, 'TRUE'],
  251.                [0, 1, 0, 1, 'TRUE'],
  252.                [0, 2, 2, 0, 'TRUE'],
  253.                [0, 1, 2, 1, 'TRUE'],
  254.                [2, 0, 0, 2, 'TRUE']]
  255.  
  256.     tree = build_tree(my_data)
  257.     print_tree(tree)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top