Advertisement
Pesovska

Drva2

May 24th, 2017
235
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.12 KB | None | 0 0
  1. trainingData=[['slashdot','USA','yes',18,'None'],
  2.         ['google','France','yes',23,'Premium'],
  3.         ['google','France','yes',23,'Basic'],
  4.         ['google','France','yes',23,'Basic'],
  5.         ['digg','USA','yes',24,'Basic'],
  6.         ['kiwitobes','France','yes',23,'Basic'],
  7.         ['google','UK','no',21,'Premium'],
  8.         ['(direct)','New Zealand','no',12,'None'],
  9.         ['(direct)','UK','no',21,'Basic'],
  10.         ['google','USA','no',24,'Premium'],
  11.         ['slashdot','France','yes',19,'None'],
  12.         ['digg','USA','no',18,'None'],
  13.         ['google','UK','no',18,'None'],
  14.         ['kiwitobes','UK','no',19,'None'],
  15.         ['digg','New Zealand','yes',12,'Basic'],
  16.         ['slashdot','UK','no',21,'None'],
  17.         ['google','UK','yes',18,'Basic'],
  18.         ['kiwitobes','France','yes',19,'Basic']]
  19.  
  20. # my_data=[line.split('\t') for line in file('decision_tree_example.txt')]
  21.  
  22. class decisionnode:
  23.       def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
  24.          self.col=col
  25.          self.value=value
  26.          self.results=results
  27.          self.tb=tb
  28.          self.fb=fb
  29.  
  30. def sporedi_broj(row,column,value):
  31.   return row[column]>=value
  32.  
  33. def sporedi_string(row,column,value):
  34.   return row[column]==value
  35.  
  36. # Divides a set on a specific column. Can handle numeric
  37. # or nominal values
  38. def divideset(rows,column,value):
  39.     # Make a function that tells us if a row is in
  40.     # the first group (true) or the second group (false)
  41.     split_function=None
  42.     if isinstance(value,int) or isinstance(value,float): # ako vrednosta so koja sporeduvame e od tip int ili float
  43.        #split_function=lambda row:row[column]>=value # togas vrati funkcija cij argument e row i vrakja vrednost true ili false
  44.        split_function=sporedi_broj
  45.     else:
  46.        # split_function=lambda row:row[column]==value # ako vrednosta so koja sporeduvame e od drug tip (string)
  47.        split_function=sporedi_string
  48.  
  49.     # Divide the rows into two sets and return them
  50.     # set1=[row for row in rows if split_function(row)]  # za sekoj row od rows za koj split_function vrakja true
  51.     # set2=[row for row in rows if not split_function(row)] # za sekoj row od rows za koj split_function vrakja false
  52.     set1=[row for row in rows if split_function(row,column,value)]  # za sekoj row od rows za koj split_function vrakja true
  53.     set2=[row for row in rows if not split_function(row,column,value)] # za sekoj row od rows za koj split_function vrakja false
  54.     return (set1,set2)
  55.  
  56. # Create counts of possible results (the last column of
  57. # each row is the result)
  58. def uniquecounts(rows):
  59.   results={}
  60.   for row in rows:
  61.      # The result is the last column
  62.      r=row[len(row)-1]
  63.      if r not in results: results[r]=0
  64.      results[r]+=1
  65.   return results
  66.  
  67. # Probability that a randomly placed item will
  68. # be in the wrong category
  69. def giniimpurity(rows):
  70.       total=len(rows)
  71.       counts=uniquecounts(rows)
  72.       imp=0
  73.       for k1 in counts:
  74.             p1=float(counts[k1])/total
  75.             for k2 in counts:
  76.                   if k1==k2: continue
  77.                   p2=float(counts[k2])/total
  78.                   imp+=p1*p2
  79.       return imp
  80.  
  81.  
  82. # Entropy is the sum of p(x)log(p(x)) across all
  83. # the different possible results
  84. def entropy(rows):
  85.       from math import log
  86.       log2=lambda x:log(x)/log(2)
  87.       results=uniquecounts(rows)
  88.       # Now calculate the entropy
  89.       ent=0.0
  90.       for r in results.keys():
  91.             p=float(results[r])/len(rows)
  92.             ent=ent-p*log2(p)
  93.       return ent
  94.  
  95. def buildtree(rows,scoref=entropy):
  96.       if len(rows)==0: return decisionnode()
  97.       current_score=scoref(rows)
  98.  
  99.       # Set up some variables to track the best criteria
  100.       best_gain=0.0
  101.       best_criteria=None
  102.       best_sets=None
  103.  
  104.       column_count=len(rows[0])-1
  105.       for col in range(0,column_count):
  106.             # Generate the list of different values in
  107.             # this column
  108.             column_values={}
  109.             for row in rows:
  110.                   column_values[row[col]]=1
  111.                   # print
  112.             # Now try dividing the rows up for each value
  113.             # in this column
  114.             for value in column_values.keys():
  115.                   (set1,set2)=divideset(rows,col,value)
  116.  
  117.                   # Information gain
  118.                   p=float(len(set1))/len(rows)
  119.                   gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
  120.                   if gain>best_gain and len(set1)>0 and len(set2)>0:
  121.                         best_gain=gain
  122.                         best_criteria=(col,value)
  123.                         best_sets=(set1,set2)
  124.  
  125.       # Create the subbranches
  126.       if best_gain>0:
  127.             trueBranch=buildtree(best_sets[0])
  128.             falseBranch=buildtree(best_sets[1])
  129.             return decisionnode(col=best_criteria[0],value=best_criteria[1],
  130.                             tb=trueBranch, fb=falseBranch)
  131.       else:
  132.             return decisionnode(results=uniquecounts(rows))
  133.  
  134. def printtree(tree,indent=''):
  135.       # Is this a leaf node?
  136.       if tree.results!=None:
  137.             print str(tree.results)
  138.       else:
  139.             # Print the criteria
  140.             print str(tree.col)+':'+str(tree.value)+'? '
  141.             # Print the branches
  142.             print indent+'T->',
  143.             printtree(tree.tb,indent+'  ')
  144.             print indent+'F->',
  145.             printtree(tree.fb,indent+'  ')
  146.  
  147.  
  148. def classify(observation,tree):
  149.     if tree.results!=None:
  150.         if len(tree.results.keys()) == 1:
  151.             return tree.results.keys()[0]
  152.         else:
  153.             max=0
  154.             best=None
  155.             for classAttr in tree.results.keys():
  156.                 if tree.results[classAttr] > max:
  157.                     max = tree.results[classAttr]
  158.                     best = classAttr
  159.                 elif tree.results[classAttr] == max and classAttr < best:
  160.                     best = classAttr
  161.             return best
  162.     else:
  163.         vrednost=observation[tree.col]
  164.         branch=None
  165.  
  166.         if isinstance(vrednost,int) or isinstance(vrednost,float):
  167.             if vrednost>=tree.value: branch=tree.tb
  168.             else: branch=tree.fb
  169.         else:
  170.            if vrednost==tree.value: branch=tree.tb
  171.            else: branch=tree.fb
  172.  
  173.         return classify(observation,branch)
  174.  
  175.  
  176. # (s1,s2)=divideset(my_data,2,'yes')
  177. # (sa1,sa2)=divideset(my_data,0,'google')
  178. # (sb1,sb2)=divideset(my_data,1,'USA')
  179. #
  180. # print len(s1),len(s2),uniquecounts(my_data)
  181. # print entropy(my_data),giniimpurity(my_data)
  182. # print entropy(s1),giniimpurity(s1)
  183. # print entropy(s2),giniimpurity(s2)
  184. # t= buildtree(my_data)
  185. # # drawtree(t)
  186. # printtree(t)
  187. # for test_case in test_cases:
  188. #     print "Nepoznat slucaj:", test_case, " Klasifikacija: ", classify(test_case,t)
  189.  
  190.  
  191. if __name__ == "__main__":
  192.     # referrer='slashdot'
  193.     # location='UK'
  194.     # readFAQ='no'
  195.     # pagesVisited=21
  196.     # serviceChosen='Unknown'
  197.  
  198.     referrer=input()
  199.     location=input()
  200.     readFAQ=input()
  201.     pagesVisited=input()
  202.     serviceChosen=input()
  203.  
  204.     testCase=[referrer, location, readFAQ, pagesVisited, serviceChosen]
  205.  
  206.     t=buildtree(trainingData)
  207.     print classify(testCase,t)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement