Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import pandas as pd
import pydotplus
from IPython.display import Image
from sklearn import metrics
from sklearn import tree
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz

try:
    from sklearn.externals.six import StringIO
except ImportError:
    # Newer scikit-learn releases no longer ship sklearn.externals.six;
    # io.StringIO is the standard-library equivalent.
    from io import StringIO
# Train a depth-limited decision tree on the mushroom data set, report the
# confusion matrix plus train/test accuracy, and render the fitted tree to
# "treeq3.pdf".
#
# Assignment checklist: confusion matrix, training accuracy, test accuracy,
# classification tree, top 3 features, classify "mushy".
# TODO: the top-3 features and the "mushy" classification are not
# implemented in this script yet.

mushData = pd.read_csv('mushrooms.csv')

# The feature columns are converted to dummy (one-hot) columns before being
# handed to the tree.
# NOTE(review): iloc[:, 1:22] selects columns 1..21 only; if the CSV holds
# 22 feature columns after 'class', the last one is silently dropped --
# confirm against the file.
X = pd.get_dummies(mushData.iloc[:, 1:22])
y = mushData['class']

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.30, random_state=1
)

dt = DecisionTreeClassifier(max_depth=6, random_state=12)
dt.fit(X_train, y_train)

# Held-out evaluation.
y_pred = dt.predict(X_test)
confusionmatrix = metrics.confusion_matrix(y_test, y_pred)
print(confusionmatrix)
print(y_pred)
print(y_test)
print(y_train)

# accuracy_score is documented as (y_true, y_pred); accuracy is symmetric,
# so the value is unchanged, but the conventional order avoids surprises if
# another metric is swapped in later.
acctest = accuracy_score(y_test, y_pred)
print("Test classification accuracy: ", acctest)

y_predtrain = dt.predict(X_train)
acctrain = accuracy_score(y_train, y_predtrain)
print('Train classification accuracy: ', acctrain)

# Export the tree once, as an in-memory DOT string (out_file=None), and
# render THAT string. Previously the styled export was written as DOT text
# into a file named "treeq3.pdf" (and returned None), while the PDF was
# rendered from a second, unstyled export -- so filled/rounded never
# reached the output.
dot_data = export_graphviz(
    dt,
    out_file=None,
    feature_names=list(X.columns),
    class_names=[str(label) for label in dt.classes_],
    filled=True,
    rounded=True,
)
dectree = dot_data  # old name kept in case later code still refers to it
pydot_graph = pydotplus.graph_from_dot_data(dot_data)
pydot_graph.write_pdf("treeq3.pdf")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement