Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Train a decision-tree classifier that predicts the "Drug" column of a
# patient table from Age, Sex, BP, Cholesterol and Na_to_K, report its accuracy
# on a held-out split, and render the fitted tree to a PNG image.
#
# The code is a flattened notebook: the bare expressions below (e.g.
# ``my_data.head()``) only display something when they are the last statement
# of a notebook cell; in a plain script they are no-ops.

# io.StringIO replaces the Python-2 shim ``sklearn.externals.six.StringIO``,
# which current scikit-learn releases no longer ship.
from io import StringIO

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydotplus
from sklearn import metrics
from sklearn import preprocessing
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Plain-Python spelling of the notebook magic ``%matplotlib inline`` so the
# file also parses outside IPython/Jupyter.
try:
    get_ipython().run_line_magic("matplotlib", "inline")
except NameError:
    # Not running under IPython: keep matplotlib's default backend.
    pass

# NOTE(review): the CSV path is blank in this copy of the script; it must be
# filled in before the script can run.
my_data = pd.read_csv("")
my_data.head()

# Feature matrix. ``.values`` on mixed numeric/string columns gives an
# object-dtype array, which is what allows the in-place encoding below.
X = my_data[['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']].values
X[0:5]

# Replace each categorical column with integer codes. LabelEncoder assigns the
# codes in sorted label order, not in the order the labels are listed in fit().
le_sex = preprocessing.LabelEncoder()
le_sex.fit(['F', 'M'])
X[:, 1] = le_sex.transform(X[:, 1])

le_BP = preprocessing.LabelEncoder()
le_BP.fit(['LOW', 'NORMAL', 'HIGH'])
X[:, 2] = le_BP.transform(X[:, 2])

le_Chol = preprocessing.LabelEncoder()
le_Chol.fit(['NORMAL', 'HIGH'])
X[:, 3] = le_Chol.transform(X[:, 3])

X[0:5]

# Target vector.
y = my_data["Drug"]
y[0:5]

# Hold out 30% of the rows for testing; the fixed seed makes the split
# reproducible.
X_trainset, X_testset, y_trainset, y_testset = train_test_split(
    X, y, test_size=0.3, random_state=3
)

# Entropy-based tree limited to depth 4.
drugTree = DecisionTreeClassifier(criterion="entropy", max_depth=4)
drugTree  # in a notebook cell this displays the estimator's parameters

drugTree.fit(X_trainset, y_trainset)

# Predict on the held-out rows and compare the first few predictions with the
# true labels.
predTree = drugTree.predict(X_testset)
print(predTree[0:5])
print(y_testset[0:5])

print("DecisionTrees's Accuracy: ", metrics.accuracy_score(y_testset, predTree))

# --- Tree visualisation -----------------------------------------------------
dot_data = StringIO()
filename = "drugtree.png"
# NOTE(review): assumes the first five CSV columns are exactly the five
# feature columns selected for X, in the same order - confirm against the data.
featureNames = my_data.columns[0:5]
targetNames = my_data["Drug"].unique().tolist()

# Names are passed as plain lists of strings rather than a pandas Index or a
# numpy array, so the call does not depend on how a given scikit-learn version
# treats other array-likes. np.unique returns the labels sorted.
out = tree.export_graphviz(
    drugTree,
    feature_names=list(featureNames),
    out_file=dot_data,
    class_names=np.unique(y_trainset).tolist(),
    filled=True,
    special_characters=True,
    rotate=False,
)

# Render the DOT source to a PNG on disk, then load and display that image.
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png(filename)
img = mpimg.imread(filename)
# NOTE(review): figsize is in inches; (100, 200) is a very large canvas and
# may be slow or memory-hungry to render.
plt.figure(figsize=(100, 200))
plt.imshow(img, interpolation='nearest')
# Needed to actually display the figure when run as a script; harmless with
# the inline notebook backend.
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement