SHARE
TWEET

Untitled

a guest Jun 18th, 2019 57 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  # Count how often each label occurs; the most frequent label is what the
  # tree reports as its prediction.
  def count_values(rows):
      """Return a dict mapping each label to its number of occurrences.

      The label of a row is its last element (``row[-1]``).

      Args:
          rows: Iterable of indexable rows; labels must be hashable.

      Returns:
          A plain ``dict`` of ``{label: count}`` with keys in first-seen
          order. Empty when ``rows`` is empty.

      Raises:
          IndexError: If a row is an empty sequence.
      """
      # Function-scope import keeps this block self-contained.
      from collections import Counter

      # Convert back to a plain dict so callers (and printed output) see the
      # same type as before.
      return dict(Counter(row[-1] for row in rows))
  10.  
  # Our decision node.
  class DecisionNode:
      """Internal tree node that routes a row to one of two subtrees.

      Attributes:
          feature: Index into a row of the value this node tests.
          value: Threshold compared against ``row[feature]`` (``<=``).
          true_branch: Subtree taken from ``splits[0]``.
          false_branch: Subtree taken from ``splits[1]``.
          score: Score of the split that produced this node.
      """

      def __init__(self, splits, feature, value, score):
          """Store the split description.

          Args:
              splits: Indexable pair ``[true_branch, false_branch]``.
              feature: Index of the tested feature.
              value: Threshold for the test.
              score: Score reported for the split.
          """
          self.feature = feature
          self.value = value
          self.true_branch = splits[0]
          self.false_branch = splits[1]
          self.score = score

      def __repr__(self):
          """Return a short debugging representation (children omitted)."""
          return "%s(feature=%r, value=%r, score=%r)" % (
              type(self).__name__,
              self.feature,
              self.value,
              self.score,
          )
  19.    
  # Leaf object of our decision tree.
  class Leaf:
      """Terminal node holding the label counts of the rows that reached it.

      Attributes:
          predictions: Result of ``count_values(rows)``.
      """

      def __init__(self, rows):
          """Count the labels of ``rows`` and keep them as ``predictions``."""
          self.predictions = count_values(rows)

      def __repr__(self):
          """Return a short debugging representation."""
          return "%s(%r)" % (type(self).__name__, self.predictions)
  24.  
  # Function needed to build the decision tree.
  def build_tree(dataset):
      """Recursively build a decision tree for ``dataset``.

      ``best_split`` (defined outside this block) is read as a mapping with
      the keys ``'splits'`` (a pair of row groups), ``'feature'``,
      ``'value'`` and ``'score'``.

      Args:
          dataset: Rows to partition; passed unchanged to ``best_split``.

      Returns:
          A ``DecisionNode`` for the best split, or a ``Leaf`` over
          ``dataset`` when that split leaves one side empty.
      """
      split_info = best_split(dataset)
      true_rows = split_info['splits'][0]
      false_rows = split_info['splits'][1]
      feature = split_info['feature']
      value = split_info['value']
      score = split_info['score']

      # A split with an empty side separates nothing: recursing would hand
      # the same rows to best_split again and never terminate, and an empty
      # Leaf has no majority label to print. Stop here instead.
      # NOTE(review): assumes both groups support len() - confirm against
      # best_split.
      if len(true_rows) == 0 or len(false_rows) == 0:
          return Leaf(dataset)

      if score == 0:
          # Presumably an impurity of zero, i.e. nothing left to separate in
          # either group, so both children become leaves.
          children = [Leaf(true_rows), Leaf(false_rows)]
      else:
          children = [build_tree(true_rows), build_tree(false_rows)]
      return DecisionNode(children, feature, value, score)
  34.  
  # Function needed to visualize our decision tree.
  def print_tree(node, indentation="   "):
      """Print the subtree rooted at ``node``, one level per indent step.

      Relies on two names defined outside this block: ``class_name``
      (indexed by label) and ``iris.feature_names`` (indexed by feature).

      Args:
          node: A ``Leaf`` or a node exposing ``feature``, ``value``,
              ``score``, ``true_branch`` and ``false_branch``.
          indentation: Prefix written before every line of this level.
      """
      if isinstance(node, Leaf):
          counts = node.predictions
          print(indentation + "PREDICTION", counts)
          # Most frequent label; on a tie the label counted first wins.
          majority = max(counts, key=counts.get)
          print(indentation + "Class", class_name[majority])
          return

      question = 'Is %s <= %.3f?' % (iris.feature_names[node.feature], node.value)
      print(indentation + question)
      print(indentation + 'GINI = ', node.score)

      # Both children share one indent step so siblings line up; previously
      # the true branch was shifted one column further than the false one.
      child_indentation = indentation + "    "

      print(indentation + "True Branch")
      print_tree(node.true_branch, child_indentation)

      print(indentation + "False Branch")
      print_tree(node.false_branch, child_indentation)
  51.  
  # Find the prediction for a row.
  def prediction(node, row):
      """Return the label counts of the leaf that ``row`` ends up in.

      Args:
          node: Root of the (sub)tree: a ``Leaf`` or a node exposing
              ``feature``, ``value``, ``true_branch`` and ``false_branch``.
          row: Indexable sample; ``row[node.feature]`` is compared with
              ``node.value`` at every internal node.

      Returns:
          The ``predictions`` attribute of the reached ``Leaf``.
      """
      # Walk down iteratively; each step picks the same child the recursive
      # form would, without growing the call stack.
      while not isinstance(node, Leaf):
          if row[node.feature] <= node.value:
              node = node.true_branch
          else:
              node = node.false_branch
      return node.predictions
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top