Advertisement
lameski

Python decision_tree prob1

Jun 12th, 2017
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.12 KB | None | 0 0
# Training data: one row per site visit.
# Columns: referrer, country, read-FAQ flag ('yes'/'no'), pages visited,
# and the service chosen — the class label, always in the last column.
trainingData=[['slashdot','USA','yes',18,'None'],
        ['google','France','yes',23,'Premium'],
        ['google','France','yes',23,'Basic'],
        ['google','France','yes',23,'Basic'],
        ['digg','USA','yes',24,'Basic'],
        ['kiwitobes','France','yes',23,'Basic'],
        ['google','UK','no',21,'Premium'],
        ['(direct)','New Zealand','no',12,'None'],
        ['(direct)','UK','no',21,'Basic'],
        ['google','USA','no',24,'Premium'],
        ['slashdot','France','yes',19,'None'],
        ['digg','USA','no',18,'None'],
        ['google','UK','no',18,'None'],
        ['kiwitobes','UK','no',19,'None'],
        ['digg','New Zealand','yes',12,'Basic'],
        ['slashdot','UK','no',21,'None'],
        ['google','UK','yes',18,'Basic'],
        ['kiwitobes','France','yes',19,'Basic']]
  19.  
  20.  
  21.  
  22.  
  23. class decisionnode:
  24. def __init__(self,col=-1,value=None,results=None,tb=None,fb=None, lvl=0):
  25. self.col=col
  26. self.value=value
  27. self.results=results
  28. self.tb=tb
  29. self.fb=fb
  30. self.lvl = lvl
  31.  
  32. def sporedi_broj(row,column,value):
  33. return row[column]>=value
  34.  
  35. def sporedi_string(row,column,value):
  36. return row[column]==value
  37.  
  38. # Divides a set on a specific column. Can handle numeric
  39. # or nominal values
  40. def divideset(rows,column,value):
  41. # Make a function that tells us if a row is in
  42. # the first group (true) or the second group (false)
  43. split_function=None
  44. if isinstance(value,int) or isinstance(value,float): # ako vrednosta so koja sporeduvame e od tip int ili float
  45. #split_function=lambda row:row[column]>=value # togas vrati funkcija cij argument e row i vrakja vrednost true ili false
  46. split_function=sporedi_broj
  47. else:
  48. # split_function=lambda row:row[column]==value # ako vrednosta so koja sporeduvame e od drug tip (string)
  49. split_function=sporedi_string
  50.  
  51. # Divide the rows into two sets and return them
  52. # set1=[row for row in rows if split_function(row)] # za sekoj row od rows za koj split_function vrakja true
  53. # set2=[row for row in rows if not split_function(row)] # za sekoj row od rows za koj split_function vrakja false
  54. set1=[row for row in rows if split_function(row,column,value)] # za sekoj row od rows za koj split_function vrakja true
  55. set2=[row for row in rows if not split_function(row,column,value)] # za sekoj row od rows za koj split_function vrakja false
  56. return (set1,set2)
  57.  
  58. # Create counts of possible results (the last column of
  59. # each row is the result)
  60. def uniquecounts(rows):
  61. results={}
  62. for row in rows:
  63. # The result is the last column
  64. r=row[len(row)-1]
  65. if r not in results: results[r]=0
  66. results[r]+=1
  67. return results
  68.  
  69. # Probability that a randomly placed item will
  70. # be in the wrong category
  71. def giniimpurity(rows):
  72. total=len(rows)
  73. counts=uniquecounts(rows)
  74. imp=0
  75. for k1 in counts:
  76. p1=float(counts[k1])/total
  77. for k2 in counts:
  78. if k1==k2: continue
  79. p2=float(counts[k2])/total
  80. imp+=p1*p2
  81. return imp
  82.  
  83.  
  84. # Entropy is the sum of p(x)log(p(x)) across all
  85. # the different possible results
  86. def entropy(rows):
  87. from math import log
  88. log2=lambda x:log(x)/log(2)
  89. results=uniquecounts(rows)
  90. # Now calculate the entropy
  91. ent=0.0
  92. for r in results.keys():
  93. p=float(results[r])/len(rows)
  94. ent=ent-p*log2(p)
  95. return ent
  96.  
  97. def buildtree(rows,scoref=entropy):
  98. if len(rows)==0: return decisionnode()
  99. current_score=scoref(rows)
  100.  
  101. # Set up some variables to track the best criteria
  102. best_gain=0.0
  103. best_criteria=None
  104. best_sets=None
  105. n=0
  106. column_count=len(rows[0])-1
  107. for col in range(0,column_count):
  108. # Generate the list of different values in
  109. # this column
  110. column_values={}
  111. for row in rows:
  112. column_values[row[col]]=1
  113. print
  114. # Now try dividing the rows up for each value
  115. # in this column
  116. for value in column_values.keys():
  117. (set1,set2)=divideset(rows,col,value)
  118.  
  119. # Information gain
  120. p=float(len(set1))/len(rows)
  121. gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
  122. if gain>best_gain and len(set1)>0 and len(set2)>0:
  123. best_gain=gain
  124. best_criteria=(col,value)
  125. best_sets=(set1,set2)
  126.  
  127. # Create the subbranches
  128. if best_gain>0:
  129.  
  130. trueBranch=buildtree(best_sets[0])
  131. falseBranch=buildtree(best_sets[1])
  132.  
  133. return decisionnode(col=best_criteria[0],value=best_criteria[1],
  134. tb=trueBranch, fb=falseBranch, lvl=n)
  135. else:
  136. n+=1
  137. return decisionnode(results=uniquecounts(rows))
  138.  
  139. def printtree(tree,indent=''):
  140. # Is this a leaf node?
  141. if tree.results!=None:
  142. print (str(tree.results))
  143. else:
  144. # Print the criteria
  145. print (str(tree.col)+':'+str(tree.value)+'? ')
  146. # Print the branches
  147. print (indent + 'T->')
  148.  
  149. printtree(tree.tb,lvl + 1,indent+' ')
  150. print (indent+'F->')
  151. printtree(tree.fb,lvl + 1 ,indent+' ')
  152.  
  153.  
  154. def classify(observation,tree):
  155. if tree.results!=None:
  156. return tree.results
  157. else:
  158. vrednost=observation[tree.col]
  159. branch=None
  160.  
  161. if isinstance(vrednost,int) or isinstance(vrednost,float):
  162. if vrednost>=tree.value: branch=tree.tb
  163. else: branch=tree.fb
  164. else:
  165. if vrednost==tree.value: branch=tree.tb
  166. else: branch=tree.fb
  167.  
  168. return classify(observation,branch)
  169.  
  170. if __name__ == "__main__":
  171. referrer='google'
  172. location='US'
  173. readFAQ='no'
  174. pagesVisited=20
  175. serviceChosen='Premium'
  176.  
  177.  
  178.  
  179. testCase=[referrer, location, readFAQ, pagesVisited, serviceChosen]
  180. trainingData.append(testCase)
  181. t=buildtree(trainingData)
  182. l1 = len(trainingData)//2
  183. l2 = len(trainingData)-1
  184. t1 = trainingData[:l1]
  185. t2 = trainingData[l1:]
  186.  
  187. tree1 = buildtree(t1)
  188. tree2 = buildtree(t2)
  189. ke1 = classify(testCase, tree1)
  190. ke2 = classify(testCase, tree2)
  191. lista = []
  192. lista2 =[]
  193. for val in ke1.values():
  194. lista.append(val)
  195. for val in ke2.values():
  196. lista2.append(val)
  197.  
  198. lista.sort()
  199. lista2.sort()
  200.  
  201. num1 = lista[len(lista)-1]
  202. num2 = lista2[len(lista2)-1]
  203.  
  204. lista_key = []
  205. lista_key2 = []
  206.  
  207. for key in ke1.keys():
  208. if ke1.get(key)==num1:
  209. lista_key.append(key)
  210. for key in ke2.keys():
  211. if ke2.get(key)==num2
  212. lista_key2.append(key)
  213. lista_key.sort()
  214. lista_key2.sort()
  215. print (lista_key[0] + " " + lista_key2[0])
  216. # if ke1[0] == ke2[0]:
  217. # print (ke1[0])
  218. # else:
  219. # print ("FALSE")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement