Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class RandomForest:
    """Ensemble of ``DecisionTree`` models, each fit on a random row subset.

    Every tree is built from ``sample_sz`` rows chosen at random (without
    replacement) from the training data; predictions are the element-wise
    mean of the individual trees' predictions.
    """

    def __init__(self, x, y, n_trees, sample_sz, min_leaf=5, depth=10):
        """Store the training data and hyper-parameters, then build the trees.

        Args:
            x: Training features; must support ``.iloc`` positional row
                selection (presumably a pandas DataFrame -- confirm with
                callers).
            y: Training targets, indexed with an array of integer positions
                (presumably a NumPy array aligned with ``x`` -- confirm).
            n_trees: Number of trees to build.
            sample_sz: Number of rows drawn for each tree.
            min_leaf: Forwarded to every ``DecisionTree`` as ``min_leaf``.
            depth: Forwarded to every ``DecisionTree`` as ``depth``.
        """
        # NOTE: this reseeds NumPy's *global* RNG so forests are reproducible
        # from run to run; kept as-is because callers may rely on it.
        np.random.seed(42)
        self.x, self.y = x, y
        self.sample_sz, self.min_leaf, self.depth = sample_sz, min_leaf, depth
        self.trees = [self.create_tree() for _ in range(n_trees)]

    def create_tree(self):
        """Build one ``DecisionTree`` on a random sample of the training rows."""
        # Bagging: shuffle all row positions and keep the first ``sample_sz``.
        rnd_idxs = np.random.permutation(len(self.y))[:self.sample_sz]
        return DecisionTree(
            self.x.iloc[rnd_idxs],
            self.y[rnd_idxs],
            min_leaf=self.min_leaf,
            # Use the configured depth; previously a literal 10 was passed
            # here, so the constructor's ``depth`` argument had no effect.
            depth=self.depth,
        )

    def predict(self, x):
        """Return the mean, across all trees, of each tree's ``predict(x)``."""
        return np.mean([t.predict(x) for t in self.trees], axis=0)
Add Comment
Please, Sign In to add comment