Advertisement
toweber

decision_tree_cost_complexity

Aug 15th, 2022
718
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.66 KB | None | 0 0
  1. from sklearn import tree
  2. from sklearn.model_selection import train_test_split
  3. import numpy as np
  4. import ipdb
  5.  
  6.  
  7. np.random.seed(0)
  8.  
  9. X = np.random.random([500,2])
  10. Y = []
  11. for x_value in X:
  12.     y_value = 0
  13.     if (x_value[0] > 0.5) and (x_value[1] > 0.5):  # AND "logic"
  14.         y_value = 1
  15.  
  16.     #random part
  17.     #if np.random.random() > 0.9:
  18.     if np.random.random() > 0.9:
  19.         y_value = not y_value
  20.     Y.append(y_value)
  21.  
  22. Y = np.array(Y)        
  23.  
  24. feature_names = ['x1','x2']
  25. target_names = ['falso','verdadeiro']
  26.  
  27.  
  28. X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.2, random_state = 0)
  29. # criar a árvore inicial
  30. clf = tree.DecisionTreeClassifier()
  31.  
  32. # encontrar os elos fracos (valores de alfa onde as "mudanças ocorrem")
  33. path = clf.cost_complexity_pruning_path(X_train, Y_train)
  34. ccp_alphas, impurities = path.ccp_alphas, path.impurities
  35.  
  36. # criar uma árvore para cada valor de alfa
  37. clfs = []
  38. for ccp_alpha in ccp_alphas:
  39.     clf = tree.DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
  40.     clf.fit(X_train, Y_train)
  41.     clfs.append(clf)
  42.  
  43. train_scores = [clf.score(X_train, Y_train) for clf in clfs]
  44. test_scores = [clf.score(X_test, Y_test) for clf in clfs]
  45.  
  46. ipdb.set_trace()
  47.  
  48.  
  49. # escolher a melhor
  50.  
  51.  
  52. #.
  53. #.
  54. #.
  55.  
  56. clf = clf.fit(X_train, Y_train)
  57.  
  58. # export in graphical format
  59. import graphviz
  60. dot_data = tree.export_graphviz(clf, out_file=None, filled=False, rounded=True, impurity=True,
  61. class_names=target_names,
  62.                                 feature_names=feature_names
  63. )
  64.  
  65. graph = graphviz.Source(dot_data)
  66. graph.render("graph")
  67.  
  68. # export in text format
  69. r = tree.export_text(clf)
  70. print('\n'+r)
  71.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement