Advertisement
Guest User

Untitled

a guest
Apr 23rd, 2019
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.38 KB | None | 0 0
  1. from sklearn.tree import _tree
  2.  
  3. tree_template = '''
  4. def tree{i}(inputs):
  5.  
  6. tree_undefined = {tree_undefined}
  7.  
  8. features = {features}
  9. thresholds = {thresholds}
  10. children_left = {children_left}
  11. children_right = {children_right}
  12. values = {values}
  13.  
  14. node = 0
  15. while features[node] != tree_undefined:
  16. feat = features[node]
  17. threshold = thresholds[node]
  18. if inputs[feat] <= threshold:
  19. node = children_left[node]
  20. else:
  21. node = children_right[node]
  22.  
  23. output = values[node]
  24.  
  25. return output
  26. '''
  27.  
  28. template_footer = '''
  29.  
  30. def forest(inputs):
  31. return ({combined_trees}) / {n}
  32. '''
  33.  
  34. template_final = ''
  35.  
  36. n = len(rfr.estimators_)
  37. for i, model in enumerate(rfr.estimators_):
  38. template_final += tree_template.format(
  39. i=i,
  40. tree_undefined=_tree.TREE_UNDEFINED,
  41. features=repr(model.tree_.feature.tolist()),
  42. thresholds=repr(model.tree_.threshold.tolist()),
  43. children_left=repr(model.tree_.children_left.tolist()),
  44. children_right=repr(model.tree_.children_right.tolist()),
  45. values=repr([val[0][0] for val in model.tree_.value]),
  46. )
  47.  
  48. template_final += template_footer.format(
  49. combined_trees=' + '.join(['tree{i}(inputs)'.format(i=i) for i in range(n)]),
  50. n=float(n)
  51. )
  52.  
  53. # execute the constructed code to load the function `forest` in the environment
  54. exec(template_final)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement