Advertisement
nanorocks

decision_tree_lab1_ex

May 10th, 2018
285
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.24 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2.  
  3. """
  4. Да се промени класата за дрво на одлука да чува и информација на кое ниво во дрвото се наоѓа јазолот.
  5. Потоа да се променат и функциите за градење и печатење на дрвото така што за секој јазол ќе се печати и нивото.
  6. Коренот е на нулто ниво. На излез со функцијата printTree треба да се испечати даденото тренинг множество.
  7. Прочитана инстанца од стандарден влез да се додаде на тренинг множеството и потоа да се истренира и испечати истото.
  8. """
  9.  
# Training set. Each row is:
#   [referrer, location, read_FAQ ('yes'/'no'), pages_visited, service_chosen]
# The last column (service chosen) is the class label the tree predicts.
trainingData=[['slashdot','USA','yes',18,'None'],
        ['google','France','yes',23,'Premium'],
        ['google','France','yes',23,'Basic'],
        ['google','France','yes',23,'Basic'],
        ['digg','USA','yes',24,'Basic'],
        ['kiwitobes','France','yes',23,'Basic'],
        ['google','UK','no',21,'Premium'],
        ['(direct)','New Zealand','no',12,'None'],
        ['(direct)','UK','no',21,'Basic'],
        ['google','USA','no',24,'Premium'],
        ['slashdot','France','yes',19,'None'],
        ['digg','USA','no',18,'None'],
        ['google','UK','no',18,'None'],
        ['kiwitobes','UK','no',19,'None'],
        ['digg','New Zealand','yes',12,'Basic'],
        ['slashdot','UK','no',21,'None'],
        ['google','UK','yes',18,'Basic'],
        ['kiwitobes','France','yes',19,'Basic']]
  28.  
  29. class decisionnode:
  30.     def __init__(self, col=-1, value=None, results=None, tb=None, fb=None, lvl=0):
  31.         self.col = col
  32.         self.value = value
  33.         self.results = results
  34.         self.tb = tb
  35.         self.fb = fb
  36.         self.lvl = lvl
  37.  
  38. def sporedi_broj(row, column, value):
  39.     return row[column] >= value
  40.  
  41. def sporedi_string(row, column, value):
  42.     return row[column] == value
  43.  
  44. def divideset(rows, column, value):
  45.     split_function = None
  46.     # print(split_function)
  47.     if isinstance(value, int) or isinstance(value, float):
  48.         split_function = sporedi_broj
  49.     else:
  50.         split_function = sporedi_string
  51.     set_false = []
  52.     set_true = []
  53.     for row in rows:
  54.         if split_function(row, column, value):
  55.             set_true.append(row)
  56.         else:
  57.             set_false.append(row)
  58.     # print(len(set_true),len(set_false))
  59.     set1 = [row for row in rows if split_function(row, column, value)]  # za sekoj row od rows za koj split_function vrakja true
  60.     set2 = [row for row in rows if not split_function(row, column, value)]  # za sekoj row od rows za koj split_function vrakja false
  61.  
  62.     return (set_true, set_false)
  63.  
  64. def uniquecounts(rows):
  65.     results = {}
  66.     for row in rows:
  67.         # The result is the last column
  68.         r = row[-1]
  69.         results.setdefault(r, 0)
  70.         results[r] += 1
  71.     return results
  72.  
  73. def log2(x):
  74.     from math import log
  75.     l2 = log(x) / log(2)
  76.     return l2
  77.  
  78. # Entropy is the sum of p(x)log(p(x)) across all
  79. # the different possible results
  80. def entropy(rows):
  81.     results = uniquecounts(rows)
  82.     # Now calculate the entropy
  83.     ent = 0.0
  84.     for r in results.keys():
  85.         p = float(results[r]) / len(rows)
  86.         ent = ent - p * log2(p)
  87.     return ent
  88.  
  89. def buildtree(rows, scoref=entropy, l=-1):
  90.     if len(rows) == 0: return decisionnode()
  91.     current_score = scoref(rows)
  92.     # Set up some variables to track the best criteria
  93.     best_gain = 0.0
  94.     best_column = -1
  95.     best_value = None
  96.     best_subsetf = None
  97.     best_subsett = None
  98.     column_count = len(rows[0]) - 1
  99.     for col in range(column_count):
  100.         # Generate the list of different values in
  101.         # this column
  102.         column_values = set()
  103.         for row in rows:
  104.             column_values.add(row[col])
  105.         # Now try dividing the rows up for each value
  106.         # in this column
  107.         for value in column_values:
  108.             (set1, set2) = divideset(rows, col, value)
  109.  
  110.             # Information gain
  111.             p = float(len(set1)) / len(rows)
  112.             gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  113.             if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  114.                 best_gain = gain
  115.                 #best_column = col
  116.                 #best_value = value
  117.                 #best_subsett = set1
  118.                 #best_subsetf = set2
  119.                 best_criteria = (col, value)
  120.                 best_sets = (set1, set2)
  121.  
  122.     # Create the subbranches
  123.     if best_gain > 0:
  124.         l+=1
  125.         trueBranch = buildtree(best_sets[0], scoref, l=l)
  126.         falseBranch = buildtree(best_sets[1], scoref, l=l)
  127.         return decisionnode(col=best_criteria[0], value=best_criteria[1],
  128.                             tb=trueBranch, fb=falseBranch, lvl=l)
  129.     else:
  130.         return decisionnode(results=uniquecounts(rows))
  131.  
  132.  
  133. def printtree(tree, indent=''):
  134.     # Is this a leaf node?
  135.     if tree.results != None:
  136.         print(tree.results)
  137.     else:
  138.         # Print the criteria
  139.         lvl = tree.lvl
  140.         print(str(tree.col) + ':' + str(tree.value) + '? '+'Level='+str(tree.lvl))
  141.         # Print the branches
  142.         print(indent + 'T->'),
  143.         printtree(tree.tb, indent + '  ')
  144.         print(indent + 'F->'),
  145.         printtree(tree.fb, indent + '  ')
  146.  
  147.  
  148. def classify(observation, tree):
  149.     if tree.results != None:
  150.         return tree.results
  151.     else:
  152.         vrednost = observation[tree.col]
  153.         branch = None
  154.  
  155.         if isinstance(vrednost, int) or isinstance(vrednost, float):
  156.             if vrednost >= tree.value:
  157.                 branch = tree.tb
  158.             else:
  159.                 branch = tree.fb
  160.         else:
  161.             if vrednost == tree.value:
  162.                 branch = tree.tb
  163.             else:
  164.                 branch = tree.fb
  165.  
  166.         return classify(observation, branch)
  167.  
  168. if __name__ == '__main__':
  169.  
  170.     referrer = 'google'
  171.     location = 'UK'
  172.     readFAQ = 'no',
  173.     pagesVisited = 18
  174.     serviceChosen = 'None'
  175.  
  176.     tmp = [referrer, location, readFAQ, pagesVisited, serviceChosen]
  177.     trainingData.append(tmp)
  178.  
  179.     t = buildtree(trainingData)
  180.  
  181.     printtree(t)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement