Advertisement
nanorocks

decision_tree_lab2_ex

May 10th, 2018
361
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.27 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2.  
  3. """
  4. Да се промени функцијата за предвидување, така што таа ќе ја печати само класата која ја предвидува (а не речник како сега).
  5. Притоа да се проверува дали во листот има повеќе од една класа. Ако има само една класа тогаш се предвидува истата, но ако има
  6. повеќе од една треба да се испечати таа со најголем број на инстанци. Ако во листот има неколку класи со ист број на инстанци да
  7. се предвиде првата класа по азбучен ред.
  8. """
  9.  
  10.  
  11. trainingData=[['slashdot','USA','yes',18,'None'],
  12.         ['google','France','yes',23,'Premium'],
  13.         ['google','France','yes',23,'Basic'],
  14.         ['google','France','yes',23,'Basic'],
  15.         ['digg','USA','yes',24,'Basic'],
  16.         ['kiwitobes','France','yes',23,'Basic'],
  17.         ['google','UK','no',21,'Premium'],
  18.         ['(direct)','New Zealand','no',12,'None'],
  19.         ['(direct)','UK','no',21,'Basic'],
  20.         ['google','USA','no',24,'Premium'],
  21.         ['slashdot','France','yes',19,'None'],
  22.         ['digg','USA','no',18,'None'],
  23.         ['google','UK','no',18,'None'],
  24.         ['kiwitobes','UK','no',19,'None'],
  25.         ['digg','New Zealand','yes',12,'Basic'],
  26.         ['slashdot','UK','no',21,'None'],
  27.         ['google','UK','yes',18,'Basic'],
  28.         ['kiwitobes','France','yes',19,'Basic']]
  29.  
  30.  
  31. class decisionnode:
  32.     def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
  33.         self.col = col
  34.         self.value = value
  35.         self.results = results
  36.         self.tb = tb
  37.         self.fb = fb
  38.  
  39.  
  40. def sporedi_broj(row, column, value):
  41.     return row[column] >= value
  42.  
  43.  
  44. def sporedi_string(row, column, value):
  45.     return row[column] == value
  46.  
  47. # Divides a set on a specific column. Can handle numeric
  48. # or nominal values
  49. def divideset(rows, column, value):
  50.     # Make a function that tells us if a row is in
  51.     # the first group (true) or the second group (false)
  52.     split_function = None
  53.     if isinstance(value, int) or isinstance(value, float):  # ako vrednosta so koja sporeduvame e od tip int ili float
  54.         # split_function=lambda row:row[column]>=value # togas vrati funkcija cij argument e row i vrakja vrednost true ili false
  55.         split_function = sporedi_broj
  56.     else:
  57.         # split_function=lambda row:row[column]==value # ako vrednosta so koja sporeduvame e od drug tip (string)
  58.         split_function = sporedi_string
  59.  
  60.     # Divide the rows into two sets and return them
  61.     # set1=[row for row in rows if split_function(row)]  # za sekoj row od rows za koj split_function vrakja true
  62.     # set2=[row for row in rows if not split_function(row)] # za sekoj row od rows za koj split_function vrakja false
  63.     set1 = [row for row in rows if
  64.             split_function(row, column, value)]  # za sekoj row od rows za koj split_function vrakja true
  65.     set2 = [row for row in rows if
  66.             not split_function(row, column, value)]  # za sekoj row od rows za koj split_function vrakja false
  67.     return (set1, set2)
  68.  
  69.  
  70. # Create counts of possible results (the last column of
  71. # each row is the result)
  72. def uniquecounts(rows):
  73.     results = {}
  74.     for row in rows:
  75.         # The result is the last column
  76.         r = row[len(row) - 1]
  77.         if r not in results: results[r] = 0
  78.         results[r] += 1
  79.     return results
  80.  
  81.  
  82. # Probability that a randomly placed item will
  83. # be in the wrong category
  84. def giniimpurity(rows):
  85.     total = len(rows)
  86.     counts = uniquecounts(rows)
  87.     imp = 0
  88.     for k1 in counts:
  89.         p1 = float(counts[k1]) / total
  90.         for k2 in counts:
  91.             if k1 == k2: continue
  92.             p2 = float(counts[k2]) / total
  93.             imp += p1 * p2
  94.     return imp
  95.  
  96.  
  97. # Entropy is the sum of p(x)log(p(x)) across all
  98. # the different possible results
  99. def entropy(rows):
  100.     from math import log
  101.     log2 = lambda x: log(x) / log(2)
  102.     results = uniquecounts(rows)
  103.     # Now calculate the entropy
  104.     ent = 0.0
  105.     for r in results.keys():
  106.         p = float(results[r]) / len(rows)
  107.         ent = ent - p * log2(p)
  108.     return ent
  109.  
  110.  
  111. def buildtree(rows, scoref=entropy):
  112.     if len(rows) == 0: return decisionnode()
  113.     current_score = scoref(rows)
  114.  
  115.     # Set up some variables to track the best criteria
  116.     best_gain = 0.0
  117.     best_criteria = None
  118.     best_sets = None
  119.  
  120.     column_count = len(rows[0]) - 1
  121.     for col in range(0, column_count):
  122.         # Generate the list of different values in
  123.         # this column
  124.         column_values = {}
  125.         for row in rows:
  126.             column_values[row[col]] = 1
  127.             #print
  128.         # Now try dividing the rows up for each value
  129.         # in this column
  130.         for value in column_values.keys():
  131.             (set1, set2) = divideset(rows, col, value)
  132.  
  133.             # Information gain
  134.             p = float(len(set1)) / len(rows)
  135.             gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  136.             if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  137.                 best_gain = gain
  138.                 best_criteria = (col, value)
  139.                 best_sets = (set1, set2)
  140.  
  141.     # Create the subbranches
  142.     if best_gain > 0:
  143.         trueBranch = buildtree(best_sets[0])
  144.         falseBranch = buildtree(best_sets[1])
  145.         return decisionnode(col=best_criteria[0], value=best_criteria[1],
  146.                             tb=trueBranch, fb=falseBranch)
  147.     else:
  148.         return decisionnode(results=uniquecounts(rows))
  149.  
  150.  
  151. def printtree(tree, indent=''):
  152.     # Is this a leaf node?
  153.     if tree.results != None:
  154.         print str(tree.results)
  155.     else:
  156.         # Print the criteria
  157.         print str(tree.col) + ':' + str(tree.value) + '? '
  158.         # Print the branches
  159.         print indent + 'T->',
  160.         printtree(tree.tb, indent + '  ')
  161.         print indent + 'F->',
  162.         printtree(tree.fb, indent + '  ')
  163.  
  164.  
  165. def classify(observation, tree):
  166.     if tree.results != None:
  167.         return tree.results
  168.     else:
  169.         vrednost = observation[tree.col]
  170.         branch = None
  171.  
  172.         if isinstance(vrednost, int) or isinstance(vrednost, float):
  173.             if vrednost >= tree.value:
  174.                 branch = tree.tb
  175.             else:
  176.                 branch = tree.fb
  177.         else:
  178.             if vrednost == tree.value:
  179.                 branch = tree.tb
  180.             else:
  181.                 branch = tree.fb
  182.  
  183.         return classify(observation, branch)
  184.  
  185.  
  186. if __name__=='__main__':
  187.  
  188.     t = buildtree(trainingData)
  189.  
  190.     #printtree(t)
  191.  
  192.     referrer='digg'
  193.     location='Macedonia'
  194.     readFAQ='no'
  195.     pagesVisited=24
  196.     serviceChosen='Unknown'
  197.  
  198.     referrer = input()
  199.     location = input()
  200.     readFAQ = input()
  201.     pagesVisited = input()
  202.     serviceChosen = input()
  203.  
  204.     tmp = [referrer, location, readFAQ, pagesVisited, serviceChosen]
  205.     c = classify(tmp,t)
  206.     max = 0
  207.     max_val = 0
  208.     l = []
  209.     for k,v in c.items():
  210.         if v>max:
  211.             max = v
  212.     #print max
  213.     for k,v in c.items():
  214.         if v==max:
  215.             l.append((k,v))
  216.     l = sorted(l)
  217.     print l[0][0]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement