Advertisement
Guest User

Untitled

a guest
Dec 10th, 2019
139
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.31 KB | None | 0 0
  1.  
  2. from sklearn import metrics
  3. import pandas as pd
  4. from sklearn import tree
  5. from sklearn.metrics import accuracy_score
  6. from sklearn.model_selection import train_test_split
  7. from sklearn.tree import DecisionTreeClassifier
  8. import pydotplus
  9. from IPython.display import Image
  10. from sklearn.externals.six import StringIO
  11. from sklearn.tree import export_graphviz
  12.  
  13. mushData=pd.read_csv('mushrooms.csv')
  14. X=pd.get_dummies(mushData.iloc[:,1:22]) #need to get dumies as
  15. y=mushData['class']
  16. #need confusion matrix, training accurady, test accuracy, class tree,
  17. #top 3 features, classify mushy
  18. X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.30,random_state=1)
  19. dt=DecisionTreeClassifier(max_depth=6,random_state=12)
  20.  
  21.  
  22. dt.fit(X_train,y_train)
  23. y_pred=dt.predict(X_test)
  24.  
  25. confusionmatrix=metrics.confusion_matrix(y_test,y_pred)
  26. print(confusionmatrix)
  27.  
  28. print(y_pred)
  29. print(y_test)
  30. print(y_train)
  31. acctest=accuracy_score(y_pred,y_test)
  32. print("Test classification accuracy: ",acctest)
  33.  
  34. y_predtrain=dt.predict(X_train)
  35. acctrain=accuracy_score(y_predtrain,y_train)
  36. print('Train classification accuracy: ',acctrain)
  37.  
  38.  
  39. dectree=tree.export_graphviz(dt)
  40. dot_data=export_graphviz(dt,out_file='treeq3.pdf',filled=True,rounded=True)
  41. pydot_graph=pydotplus.graph_from_dot_data(dectree)
  42. pydot_graph.write_pdf("treeq3.pdf")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement