Advertisement
nanorocks

decision_tree_exam_2018

May 22nd, 2018
725
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.18 KB | None | 0 0
  1. trainingData=[['twitter','USA','yes',18,'None'],
  2.         ['google','France','yes',23,'Premium'],
  3.         ['google','France','no',26,'Basic'],
  4.         ['google','Macedonia','yes',13,'None'],
  5.         ['pinterest','USA','yes',24,'Basic'],
  6.         ['bing','France','yes',23,'Basic'],
  7.         ['google','UK','no',21,'Premium'],
  8.         ['facebook','New Zealand','no',12,'None'],
  9.         ['facebook','UK','no',21,'Basic'],
  10.         ['google','USA','no',24,'Premium'],
  11.         ['twitter','France','yes',19,'None'],
  12.         ['pinterest','USA','no',18,'None'],
  13.         ['google','UK','no',18,'None'],
  14.         ['bing','UK','yes',19,'Premium'],
  15.         ['bing','Macedonia','no',10,'None'],
  16.         ['facebook','Macedonia','no',16,'Basic'],
  17.         ['bing','UK','no',19,'Basic'],
  18.         ['pinterest','Germany','no',2,'None'],
  19.         ['pinterest','USA','yes',12,'Basic'],
  20.         ['twitter','UK','no',21,'None'],
  21.         ['twitter','UK','yes',26,'Premium'],
  22.         ['google','UK','yes',18,'Basic'],
  23.         ['bing','France','yes',19,'Basic']]
  24.  
  25. test_cases=[['google','MK','no',24,'Unknown'],
  26.             ['google','MK','no',15,'Unknown'],
  27.             ['pinterest','UK','yes',21,'Unknown'],
  28.             ['pinterest','UK','no',25,'Unknown']]
  29.  
  30. # trainingData=[line.split('\t') for line in file('decision_tree_example.txt')]
  31.  
  32. class decisionnode:
  33.       def __init__(self,col=-1,value=None,results=None,tb=None,fb=None,level=0):
  34.          self.level=level
  35.          self.col=col
  36.          self.value=value
  37.          self.results=results
  38.          self.tb=tb
  39.          self.fb=fb
  40.          self.level=level
  41.  
  42. def sporedi_broj(row,column,value):
  43.   return row[column]>=value
  44.  
  45. def sporedi_string(row,column,value):
  46.   return row[column]==value
  47.  
  48. # Divides a set on a specific column. Can handle numeric
  49. # or nominal values
  50. def divideset(rows,column,value):
  51.     # Make a function that tells us if a row is in
  52.     # the first group (true) or the second group (false)
  53.     split_function=None
  54.     if isinstance(value,int) or isinstance(value,float): # ako vrednosta so koja sporeduvame e od tip int ili float
  55.        #split_function=lambda row:row[column]>=value # togas vrati funkcija cij argument e row i vrakja vrednost true ili false
  56.        split_function=sporedi_broj
  57.     else:
  58.        # split_function=lambda row:row[column]==value # ako vrednosta so koja sporeduvame e od drug tip (string)
  59.        split_function=sporedi_string
  60.  
  61.     # Divide the rows into two sets and return them
  62.     # set1=[row for row in rows if split_function(row)]  # za sekoj row od rows za koj split_function vrakja true
  63.     # set2=[row for row in rows if not split_function(row)] # za sekoj row od rows za koj split_function vrakja false
  64.     set1=[row for row in rows if split_function(row,column,value)]  # za sekoj row od rows za koj split_function vrakja true
  65.     set2=[row for row in rows if not split_function(row,column,value)] # za sekoj row od rows za koj split_function vrakja false
  66.    
  67.     return (set1,set2)
  68.  
  69. # Divides a set on a specific column. Can handle numeric
  70. # or nominal values
  71. def divideset2(rows,column,value):
  72.     # Make a function that tells us if a row is in
  73.     # the first group (true) or the second group (false)
  74.     split_function=None
  75.     if isinstance(value,int) or isinstance(value,float): # ako vrednosta so koja sporeduvame e od tip int ili float
  76.        #split_function=lambda row:row[column]>=value # togas vrati funkcija cij argument e row i vrakja vrednost true ili false
  77.        split_function=sporedi_broj
  78.     else:
  79.        # split_function=lambda row:row[column]==value # ako vrednosta so koja sporeduvame e od drug tip (string)
  80.        split_function=sporedi_string
  81.  
  82.     # Divide the rows into two sets and return them
  83.     # set1=[row for row in rows if split_function(row)]  # za sekoj row od rows za koj split_function vrakja true
  84.     # set2=[row for row in rows if not split_function(row)] # za sekoj row od rows za koj split_function vrakja false
  85.     set1=[]
  86.     set2=[]
  87.     for row in rows:
  88.       if split_function(row,column,value):
  89.         set1.append(row)
  90.       else:
  91.         set2.append(row)
  92.     return (set1,set2)
  93.  
  94.  
  95.  
  96. # Create counts of possible results (the last column of
  97. # each row is the result)
  98. def uniquecounts(rows):
  99.   results={}
  100.   for row in rows:
  101.      # The result is the last column
  102.      r=row[len(row)-1]
  103.      if r not in results: results[r]=0
  104.      results[r]+=1
  105.   return results
  106.  
  107.  
  108. # Entropy is the sum of p(x)log(p(x)) across all
  109. # the different possible results
  110. def entropy(rows):
  111.       from math import log
  112.       log2=lambda x:log(x)/log(2)
  113.       results=uniquecounts(rows)
  114.       # Now calculate the entropy
  115.       ent=0.0
  116.       for r in results.keys():
  117.             p=float(results[r])/len(rows)
  118.             ent=ent-p*log2(p)
  119.       return ent
  120.  
  121. def buildtree(rows,scoref=entropy,level=0):
  122.       if len(rows)==0: return decisionnode()
  123.       current_score=scoref(rows)
  124.  
  125.       # Set up some variables to track the best criteria
  126.       best_gain=0.0
  127.       best_criteria=None
  128.       best_sets=None
  129.      
  130.       column_count=len(rows[0])-1
  131.       for col in range(0,column_count):
  132.             # Generate the list of different values in
  133.             # this column
  134.             column_values={}
  135.             for row in rows:
  136.                   column_values[row[col]]=1
  137.                   # print row[col]
  138.             # print
  139.             # print column_values
  140.             # Now try dividing the rows up for each value
  141.             # in this column
  142.             for value in column_values.keys():
  143.                   (set1,set2)=divideset(rows,col,value)
  144.                  
  145.                   # Information gain
  146.                   p=float(len(set1))/len(rows)
  147.                   gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
  148.                   # print set1, set2, gain
  149.                   if gain>best_gain and len(set1)>0 and len(set2)>0:
  150.                         best_gain=gain
  151.                         best_criteria=(col,value)
  152.                         best_sets=(set1,set2)
  153.      
  154.       # Create the subbranches
  155.       if best_gain>0:
  156.             trueBranch=buildtree(best_sets[0],level=level+1)
  157.             falseBranch=buildtree(best_sets[1], level=level+1)
  158.             return decisionnode(col=best_criteria[0],value=best_criteria[1],
  159.                             tb=trueBranch, fb=falseBranch, level=level)
  160.       else:
  161.             return decisionnode(results=uniquecounts(rows))
  162.  
  163. def printtree(tree,indent=''):
  164.       # Is this a leaf node?
  165.       if tree.results!=None:
  166.             print str(tree.results)
  167.       else:
  168.             # Print the criteria
  169.             print str(tree.col)+':'+str(tree.value)+'?' + ' Level='+str(tree.level)
  170.             # Print the branches
  171.             print indent+'T->',
  172.             printtree(tree.tb,indent+'  ')
  173.             print indent+'F->',
  174.             printtree(tree.fb,indent+'  ')
  175.  
  176. def classify(observation,tree):
  177.     if tree.results!=None:
  178.         results=[(value,key) for key,value in tree.results.items()]
  179.         results.sort()
  180.         return results[0][1]
  181.     else:
  182.         vrednost=observation[tree.col]
  183.         branch=None
  184.  
  185.         if isinstance(vrednost,int) or isinstance(vrednost,float):
  186.             if vrednost>=tree.value: branch=tree.tb
  187.             else: branch=tree.fb
  188.         else:
  189.            if vrednost==tree.value: branch=tree.tb
  190.            else: branch=tree.fb
  191.  
  192.         return classify(observation,branch)
  193.        
  194.  
  195. def classify2(observation,tree):
  196.     if tree.results!=None:
  197.         results=[(value,key) for key,value in tree.results.items()]
  198.         results.sort()
  199.         return results[0][1]
  200.     else:
  201.         vrednost=observation[tree.col]
  202.         branch=None
  203.  
  204.         if isinstance(vrednost,int) or isinstance(vrednost,float):
  205.             if vrednost>=tree.value: branch=tree.tb
  206.             else: branch=tree.fb
  207.         else:
  208.            if vrednost==tree.value: branch=tree.tb
  209.            else: branch=tree.fb
  210.  
  211.         return classify2(observation,branch)
  212.  
  213. def classify3(observation,tree):
  214.     if tree.results!=None:
  215.         results=[(value,key) for key,value in tree.results.items()]
  216.         results.sort()
  217.         return results[0][1]
  218.     else:
  219.         vrednost=observation[tree.col]
  220.         branch=None
  221.         granka='True branch'
  222.         if isinstance(vrednost,int) or isinstance(vrednost,float):
  223.             if vrednost>=tree.value:
  224.                 branch=tree.tb
  225.             else:
  226.                 branch=tree.fb
  227.                 granka='False branch'
  228.         else:
  229.            if vrednost==tree.value:
  230.                branch=tree.tb
  231.            else:
  232.                branch=tree.fb
  233.                granka='False branch'
  234.         print 'Sporeduvam kolona i vrednost', (tree.col, tree.value)
  235.         print 'Tekovna vrednost:', vrednost
  236.         print 'Sledna granka:',granka
  237.         print 'Preostanata granka za izminuvanje:'
  238.         printtree(branch)
  239.         print
  240.         return classify3(observation,branch)
  241.  
  242. if __name__ == "__main__":
  243.     referrer=input()
  244.     location=input()
  245.     readFAQ=input()
  246.     pagesVisited=input()
  247.     serviceChosen='Unknown'
  248.  
  249.  
  250.     testCase=[referrer, location, readFAQ, pagesVisited, serviceChosen]
  251.  
  252.     t=buildtree(trainingData)
  253.     printtree(t)
  254.     print classify3(testCase,t)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement