Advertisement
fevzi02

Untitled

Dec 26th, 2023
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.72 KB | None | 0 0
  # Iris petal features: decision-tree diagram plus the decision regions it induces.
"""Fit a shallow decision tree on two iris features and visualize it.

The figure has two panels. On the left is the fitted tree, drawn with
``plot_tree``. On the right is the class the same tree predicts at every
point of a dense grid over (petal length, petal width), with the iris
samples overlaid.

Run as a script to print the held-out accuracy and show the figure.
"""

import numpy as np

# scikit-learn and matplotlib are imported inside the functions that need
# them, so ``make_grid`` can be imported and reused with only NumPy installed.


def make_grid(points, margin=1.0, step=0.01):
    """Return ``(xx, yy)`` mesh arrays covering *points* padded by *margin*.

    Parameters
    ----------
    points : array-like of shape (n_samples, 2)
        2-D coordinates whose bounding box the grid must cover.
    margin : float, default 1.0
        Padding added below the minimum and above the maximum of each axis.
    step : float, default 0.01
        Spacing between neighbouring grid points.

    Returns
    -------
    xx, yy : ndarray
        Coordinate matrices from ``np.meshgrid`` (default ``"xy"`` indexing),
        so ``xx.shape == yy.shape == (n_y, n_x)``.

    Raises
    ------
    ValueError
        If *points* is not a non-empty ``(n, 2)`` array, or *step* is not positive.
    """
    points = np.asarray(points, dtype=float)
    if points.ndim != 2 or points.shape[0] == 0 or points.shape[1] != 2:
        raise ValueError("points must be a non-empty array of shape (n_samples, 2)")
    if step <= 0:
        # np.arange cannot build an increasing grid from a non-positive step.
        raise ValueError("step must be positive")
    lo = points.min(axis=0) - margin
    hi = points.max(axis=0) + margin
    return np.meshgrid(np.arange(lo[0], hi[0], step), np.arange(lo[1], hi[1], step))


def train_tree(X, y, *, max_depth=2, test_size=0.2, random_state=42):
    """Fit a depth-limited decision tree and score it on a held-out split.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix.
    y : array-like of shape (n_samples,)
        Class labels.
    max_depth, test_size, random_state
        Forwarded to ``DecisionTreeClassifier`` / ``train_test_split``.
        The defaults are the values this script has always used.

    Returns
    -------
    clf : sklearn.tree.DecisionTreeClassifier
        The classifier fitted on the training split only.
    test_accuracy : float
        Mean accuracy of *clf* on the held-out split.
    """
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    clf = DecisionTreeClassifier(max_depth=max_depth, random_state=random_state)
    clf.fit(X_train, y_train)
    # The held-out split exists to measure generalization, so report it.
    return clf, clf.score(X_test, y_test)


def plot_tree_and_boundary(clf, X, y, feature_names, class_names, *, margin=1.0, step=0.01):
    """Draw *clf* both as a tree diagram and as decision regions over *X*.

    Parameters
    ----------
    clf : fitted classifier
        Must accept the two columns of *X* in ``predict``. Assumes integer
        labels ``0 .. len(class_names) - 1`` (true for iris). The colour
        mapping below relies on that; confirm this for other datasets.
    X : array-like of shape (n_samples, 2)
        Points to scatter; their bounding box (plus *margin*) sets the grid.
    y : array-like of shape (n_samples,)
        Integer class labels used to colour the scatter points.
    feature_names : sequence of two str
        Used for the tree's split labels and the scatter axis labels.
    class_names : sequence of str
        Indexed by class label; shown in the tree nodes and the legend.
    margin, step : float
        Grid padding and resolution, passed to ``make_grid``.

    Returns
    -------
    matplotlib.figure.Figure
        The two-panel figure (tree on the left, regions on the right).
    """
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    from sklearn.tree import plot_tree

    X = np.asarray(X)
    feature_names = [str(name) for name in feature_names]
    class_names = [str(name) for name in class_names]
    n_classes = len(class_names)

    # Predict the class at every grid point to paint the decision regions.
    xx, yy = make_grid(X, margin=margin, step=step)
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # One colour per class, with each integer label centred in its own bin.
    # This way the filled regions and the scatter points share identical
    # colours; by default contourf picks its own levels and both artists
    # normalize independently.
    cmap = plt.get_cmap("viridis", n_classes)
    levels = np.arange(n_classes + 1) - 0.5
    vmin, vmax = levels[0], levels[-1]

    fig, (ax_tree, ax_boundary) = plt.subplots(1, 2, figsize=(12, 6))

    plot_tree(clf, filled=True, feature_names=feature_names,
              class_names=class_names, ax=ax_tree)
    ax_tree.set_title("Decision Tree")

    ax_boundary.contourf(xx, yy, Z, levels=levels, cmap=cmap,
                         vmin=vmin, vmax=vmax, alpha=0.3)
    ax_boundary.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, vmin=vmin, vmax=vmax,
                        s=50, edgecolor="k")
    # Integer input to a Colormap indexes its lookup table directly, so
    # cmap(k) is exactly the colour used for label k above.
    handles = [
        Line2D([], [], marker="o", linestyle="", markeredgecolor="k",
               markerfacecolor=cmap(k), label=name)
        for k, name in enumerate(class_names)
    ]
    ax_boundary.legend(handles=handles, title="Class")
    ax_boundary.set_title("Decision Boundary")
    ax_boundary.set_xlabel(feature_names[0])
    ax_boundary.set_ylabel(feature_names[1])
    return fig


def main():
    """Train on the iris petal features, report accuracy, and show both plots."""
    import matplotlib.pyplot as plt
    from sklearn.datasets import load_iris

    iris = load_iris()
    # Keep only columns 2 and 3 (petal length/width) so the boundary is 2-D.
    X = iris.data[:, 2:]
    y = iris.target

    clf, test_accuracy = train_tree(X, y)
    print(f"Held-out accuracy: {test_accuracy:.3f}")

    fig = plot_tree_and_boundary(clf, X, y, iris.feature_names[2:], iris.target_names)
    fig.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement