Advertisement
nanorocks

python_lecture_decision_trees

Nov 5th, 2017
250
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.52 KB | None | 0 0
  1. from __future__ import print_function
  2.  
# Training data: each row is [referrer, country, read_faq, pages_viewed, service].
# The last column (subscription tier) is the outcome label the tree predicts.
my_data = [['slashdot', 'USA', 'yes', 18, 'None'],
           ['google', 'France', 'yes', 23, 'Premium'],
           ['digg', 'USA', 'yes', 24, 'Basic'],
           ['kiwitobes', 'France', 'yes', 23, 'Basic'],
           ['google', 'UK', 'no', 21, 'Premium'],
           ['(direct)', 'New Zealand', 'no', 12, 'None'],
           ['(direct)', 'UK', 'no', 21, 'Basic'],
           ['google', 'USA', 'no', 24, 'Premium'],
           ['slashdot', 'France', 'yes', 19, 'None'],
           ['digg', 'USA', 'no', 18, 'None'],
           ['google', 'UK', 'no', 18, 'None'],
           ['kiwitobes', 'UK', 'no', 19, 'None'],
           ['digg', 'New Zealand', 'yes', 12, 'Basic'],
           ['slashdot', 'UK', 'no', 21, 'None'],
           ['google', 'UK', 'yes', 18, 'Basic'],
           ['kiwitobes', 'France', 'yes', 19, 'Basic']]

# Unlabeled rows for classification ('Unknown' is a placeholder label).
test_cases = [['google', 'MK', 'no', 24, 'Unknown'],
              ['google', 'MK', 'no', 15, 'Unknown'],
              ['digg', 'UK', 'yes', 21, 'Unknown'],
              ['digg', 'UK', 'no', 25, 'Unknown']]
  24.  
  25.  
  26. # my_data=[line.split('\t') for line in file('decision_tree_example.txt')]
  27.  
class decisionnode:
    """One node of the decision tree.

    Interior nodes hold a split test (col, value) and the two subtrees
    (tb taken when the test passes, fb otherwise).  Leaf nodes hold the
    outcome counts in `results` and leave the other fields at defaults.
    """

    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col = col          # index of the column this node's split tests
        self.value = value      # comparison value for the split
        self.results = results  # dict of outcome counts (leaf nodes only)
        self.tb = tb            # subtree taken when the test is true
        self.fb = fb            # subtree taken when the test is false
  35.  
  36.  
  37. def sporedi_broj(row, column, value):
  38.     return row[column] >= value
  39.  
  40.  
  41. def sporedi_string(row, column, value):
  42.     return row[column] == value
  43.  
  44.  
  45. # Divides a set on a specific column. Can handle numeric
  46. # or nominal values
  47. def divideset(rows, column, value):
  48.     # Make a function that tells us if a row is in
  49.     # the first group (true) or the second group (false)
  50.     split_function = None
  51.     if isinstance(value, int) or isinstance(value, float):  # ako vrednosta so koja sporeduvame e od tip int ili float
  52.         # split_function=lambda row:row[column]>=value # togas vrati funkcija cij argument e row i vrakja vrednost true ili false
  53.         split_function = sporedi_broj
  54.     else:
  55.         # split_function=lambda row:row[column]==value # ako vrednosta so koja sporeduvame e od drug tip (string)
  56.         split_function = sporedi_string
  57.  
  58.     # Divide the rows into two sets and return them
  59.     set_false = []
  60.     set_true = []
  61.     for row in rows:
  62.         if split_function(row, column, value):
  63.             set_true.append(row)
  64.         else:
  65.             set_false.append(row)
  66.     set1 = [row for row in rows if
  67.             split_function(row, column, value)]  # za sekoj row od rows za koj split_function vrakja true
  68.     set2 = [row for row in rows if
  69.             not split_function(row, column, value)]  # za sekoj row od rows za koj split_function vrakja false
  70.     # return (set1, set2)
  71.     return (set_true, set_false)
  72.  
  73.  
# Smoke test: split the training data on column 3 (pages viewed) at >= 20.
st, sf = divideset(my_data, 3, 20)
print(sf)
print(st)
  77.  
  78.  
  79. # Create counts of possible results (the last column of
  80. # each row is the result)
  81. def uniquecounts(rows):
  82.     results = {}
  83.     for row in rows:
  84.         # The result is the last column
  85.         r = row[-1]
  86.         results.setdefault(r, 0)
  87.         results[r] += 1
  88.  
  89.     return results
  90.  
  91.  
# Outcome counts for the full data set and for each half of the split above.
print(uniquecounts(my_data))
print(uniquecounts(st))
print(uniquecounts(sf))
  95.  
  96.  
  97. # Probability that a randomly placed item will
  98. # be in the wrong category
  99.  
  100. def log2(x):
  101.     from math import log
  102.     l2 = log(x) / log(2)
  103.     return l2
  104.  
  105.  
  106. # Entropy is the sum of p(x)log(p(x)) across all
  107. # the different possible results
  108. def entropy(rows):
  109.     results = uniquecounts(rows)
  110.     # Now calculate the entropy
  111.     ent = 0.0
  112.     for r in results.keys():
  113.         p = float(results[r]) / len(rows)
  114.         ent = ent - p * log2(p)
  115.     return ent
  116.  
  117.  
# Entropy of the full set and of each half of the split above.
print(entropy(my_data), entropy(st), entropy(sf))
  119.  
  120.  
  121. # exit(0)
  122.  
  123.  
  124. def buildtree(rows, scoref=entropy):
  125.     if len(rows) == 0: return decisionnode()
  126.     current_score = scoref(rows)
  127.  
  128.     # Set up some variables to track the best criteria
  129.     best_gain = 0.0
  130.     best_column = -1
  131.     best_value = None
  132.     best_subsetf = None
  133.     best_subsett = None
  134.  
  135.     column_count = len(rows[0]) - 1
  136.     for col in range(column_count):
  137.         # Generate the list of different values in
  138.         # this column
  139.         column_values = set()
  140.         for row in rows:
  141.             column_values.add(row[col])
  142.         # Now try dividing the rows up for each value
  143.         # in this column
  144.         for value in column_values:
  145.             (set1, set2) = divideset(rows, col, value)
  146.  
  147.             # Information gain
  148.             p = float(len(set1)) / len(rows)
  149.             gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  150.             if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  151.                 best_gain = gain
  152.                 best_column = col
  153.                 best_value = value
  154.                 best_subsett = set1
  155.                 best_subsetf = set2
  156.                 # best_criteria = (col, value)
  157.                 # best_sets = (set1, set2)
  158.  
  159.     # Create the subbranches
  160.     if best_gain > 0:
  161.         trueBranch = buildtree(best_subsett, scoref)
  162.         falseBranch = buildtree(best_subsetf, scoref)
  163.         return decisionnode(col=best_column, value=best_value,
  164.                             tb=trueBranch, fb=falseBranch)
  165.     else:
  166.         return decisionnode(results=uniquecounts(rows))
  167.  
  168.  
# Grow the full tree from the training data.
t = buildtree(my_data)
  170.  
  171.  
  172. def printtree(tree, indent=''):
  173.     # Is this a leaf node?
  174.     if tree.results != None:
  175.         print(indent + str(sorted(tree.results.items())))
  176.     else:
  177.         # Print the criteria
  178.         print(indent + str(tree.col) + ':' + str(tree.value) + '? ')
  179.         # Print the branches
  180.         print(indent + 'T->')
  181.         printtree(tree.tb, indent + '  ')
  182.         print(indent + 'F->')
  183.         printtree(tree.fb, indent + '  ')
  184.  
  185.  
  186. printtree(t)
  187.  
  188.  
  189. # exit(0)
  190.  
  191. def classify(observation, tree):
  192.     if tree.results != None:
  193.         return tree.results
  194.     else:
  195.         vrednost = observation[tree.col]
  196.         branch = None
  197.  
  198.         if isinstance(vrednost, int) or isinstance(vrednost, float):
  199.             if vrednost >= tree.value:
  200.                 branch = tree.tb
  201.             else:
  202.                 branch = tree.fb
  203.         else:
  204.             if vrednost == tree.value:
  205.                 branch = tree.tb
  206.             else:
  207.                 branch = tree.fb
  208.  
  209.         return classify(observation, branch)
  210.  
  211.  
# Classify one unseen observation ("Nepoznat slucaj" = unknown case,
# "Klasifikacija" = classification in the commented-out batch loop).
print(classify(['google', 'MK', 'no', 19, 'Unknown'], t))
# for test_case in test_cases:
#     print("Nepoznat slucaj:", test_case, " Klasifikacija: ", classify(test_case, t))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement