Advertisement
nanorocks

lab_python_SNZ_2_decision_trees

Nov 5th, 2017
353
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.19 KB | None | 0 0
"""
Problem 2 (2 / 8)

Change the prediction function so that it prints only the class it predicts
(not a dictionary, as it does now). While doing so, check whether the leaf
contains more than one class. If there is only one class, predict that class;
if there is more than one, print the class with the largest number of
instances. If several classes in the leaf have the same number of instances,
predict the first class in alphabetical order.
"""
  7.  
  8. trainingData=[['slashdot','USA','yes',18,'None'],
  9.         ['google','France','yes',23,'Premium'],
  10.         ['google','France','yes',23,'Basic'],
  11.         ['google','France','yes',23,'Basic'],
  12.         ['digg','USA','yes',24,'Basic'],
  13.         ['kiwitobes','France','yes',23,'Basic'],
  14.         ['google','UK','no',21,'Premium'],
  15.         ['(direct)','New Zealand','no',12,'None'],
  16.         ['(direct)','UK','no',21,'Basic'],
  17.         ['google','USA','no',24,'Premium'],
  18.         ['slashdot','France','yes',19,'None'],
  19.         ['digg','USA','no',18,'None'],
  20.         ['google','UK','no',18,'None'],
  21.         ['kiwitobes','UK','no',19,'None'],
  22.         ['digg','New Zealand','yes',12,'Basic'],
  23.         ['slashdot','UK','no',21,'None'],
  24.         ['google','UK','yes',18,'Basic'],
  25.         ['kiwitobes','France','yes',19,'Basic']]
  26.  
  27.  
  28. class decisionnode:
  29.     def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
  30.         self.col = col
  31.         self.value = value
  32.         self.results = results
  33.         self.tb = tb
  34.         self.fb = fb
  35.  
  36.  
  37. def sporedi_broj(row, column, value):
  38.     return row[column] >= value
  39.  
  40.  
  41. def sporedi_string(row, column, value):
  42.     return row[column] == value
  43.  
  44.  
  45. # Divides a set on a specific column. Can handle numeric
  46. # or nominal values
  47. def divideset(rows, column, value):
  48.     # Make a function that tells us if a row is in
  49.     # the first group (true) or the second group (false)
  50.     split_function = None
  51.     if isinstance(value, int) or isinstance(value, float):  # ako vrednosta so koja sporeduvame e od tip int ili float
  52.         # split_function=lambda row:row[column]>=value # togas vrati funkcija cij argument e row i vrakja vrednost true ili false
  53.         split_function = sporedi_broj
  54.     else:
  55.         # split_function=lambda row:row[column]==value # ako vrednosta so koja sporeduvame e od drug tip (string)
  56.         split_function = sporedi_string
  57.  
  58.     # Divide the rows into two sets and return them
  59.     # set1=[row for row in rows if split_function(row)]  # za sekoj row od rows za koj split_function vrakja true
  60.     # set2=[row for row in rows if not split_function(row)] # za sekoj row od rows za koj split_function vrakja false
  61.     set1 = [row for row in rows if
  62.             split_function(row, column, value)]  # za sekoj row od rows za koj split_function vrakja true
  63.     set2 = [row for row in rows if
  64.             not split_function(row, column, value)]  # za sekoj row od rows za koj split_function vrakja false
  65.     return (set1, set2)
  66.  
  67.  
  68. # Create counts of possible results (the last column of
  69. # each row is the result)
  70. def uniquecounts(rows):
  71.     results = {}
  72.     for row in rows:
  73.         # The result is the last column
  74.         r = row[len(row) - 1]
  75.         if r not in results: results[r] = 0
  76.         results[r] += 1
  77.     return results
  78.  
  79.  
  80. # Probability that a randomly placed item will
  81. # be in the wrong category
  82. def giniimpurity(rows):
  83.     total = len(rows)
  84.     counts = uniquecounts(rows)
  85.     imp = 0
  86.     for k1 in counts:
  87.         p1 = float(counts[k1]) / total
  88.         for k2 in counts:
  89.             if k1 == k2: continue
  90.             p2 = float(counts[k2]) / total
  91.             imp += p1 * p2
  92.     return imp
  93.  
  94.  
  95. # Entropy is the sum of p(x)log(p(x)) across all
  96. # the different possible results
  97. def entropy(rows):
  98.     from math import log
  99.     log2 = lambda x: log(x) / log(2)
  100.     results = uniquecounts(rows)
  101.     # Now calculate the entropy
  102.     ent = 0.0
  103.     for r in results.keys():
  104.         p = float(results[r]) / len(rows)
  105.         ent = ent - p * log2(p)
  106.     return ent
  107.  
  108.  
  109. def buildtree(rows, scoref=entropy):
  110.     if len(rows) == 0: return decisionnode()
  111.     current_score = scoref(rows)
  112.  
  113.     # Set up some variables to track the best criteria
  114.     best_gain = 0.0
  115.     best_criteria = None
  116.     best_sets = None
  117.  
  118.     column_count = len(rows[0]) - 1
  119.     for col in range(0, column_count):
  120.         # Generate the list of different values in
  121.         # this column
  122.         column_values = {}
  123.         for row in rows:
  124.             column_values[row[col]] = 1
  125.         # Now try dividing the rows up for each value
  126.         # in this column
  127.         for value in column_values.keys():
  128.             (set1, set2) = divideset(rows, col, value)
  129.  
  130.             # Information gain
  131.             p = float(len(set1)) / len(rows)
  132.             gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
  133.             if gain > best_gain and len(set1) > 0 and len(set2) > 0:
  134.                 best_gain = gain
  135.                 best_criteria = (col, value)
  136.                 best_sets = (set1, set2)
  137.  
  138.     # Create the subbranches
  139.     if best_gain > 0:
  140.         trueBranch = buildtree(best_sets[0])
  141.         falseBranch = buildtree(best_sets[1])
  142.         return decisionnode(col=best_criteria[0], value=best_criteria[1],
  143.                             tb=trueBranch, fb=falseBranch)
  144.     else:
  145.         return decisionnode(results=uniquecounts(rows))
  146.  
  147.  
  148. def printtree(tree, indent=''):
  149.     # Is this a leaf node?
  150.     if tree.results != None:
  151.         print str(tree.results)
  152.     else:
  153.         # Print the criteria
  154.         print str(tree.col) + ':' + str(tree.value) + '? '
  155.         # Print the branches
  156.         print indent + 'T->',
  157.         printtree(tree.tb, indent + '  ')
  158.         print indent + 'F->',
  159.         printtree(tree.fb, indent + '  ')
  160.  
  161.  
  162. def classify(observation, tree):
  163.     if tree.results != None:
  164.         lista=[]
  165.         maxi=0
  166.         # print tree.results
  167.         for k in tree.results:
  168.             if tree.results[k]>=maxi:
  169.                 maxi=tree.results[k]
  170.         for k in tree.results:
  171.             if tree.results[k] == maxi:
  172.                 lista.append(k)
  173.         lista.sort()
  174.         return lista[0]
  175.     else:
  176.         vrednost = observation[tree.col]
  177.         branch = None
  178.  
  179.         if isinstance(vrednost, int) or isinstance(vrednost, float):
  180.             if vrednost >= tree.value:
  181.                 branch = tree.tb
  182.             else:
  183.                 branch = tree.fb
  184.         else:
  185.             if vrednost == tree.value:
  186.                 branch = tree.tb
  187.             else:
  188.                 branch = tree.fb
  189.  
  190.         return classify(observation, branch)
  191.  
  192.  
  193. if __name__ == "__main__":
  194.  
  195.    
  196.     referrer=input()
  197.     location=input()
  198.     readFAQ=input()
  199.     pagesVisited=input()
  200.     serviceChosen=input()
  201.  
  202.     testCase=[referrer, location, readFAQ, pagesVisited, serviceChosen]
  203.  
  204.     t=buildtree(trainingData)
  205.     print classify(testCase,t)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement