Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import matplotlib.pyplot as plt
- import pandas as pd
- import seaborn as sns
def _draw_confusion_heatmap(ax, cm, classes, title, cmap):
    """Draw one annotated confusion-matrix heatmap onto the given axes.

    Args:
        ax: Matplotlib axes to draw into.
        cm: Square matrix of counts (anything ``pd.DataFrame`` accepts).
        classes: Labels used for both the row (true) and column
            (predicted) index; length must match the matrix dimensions.
        title: Title shown above the panel.
        cmap: Colormap name or object forwarded to ``sns.heatmap``.
    """
    frame = pd.DataFrame(cm, index=classes, columns=classes)
    # fmt='g' prints integer counts without scientific notation or ".0".
    sns.heatmap(frame, annot=True, fmt='g', cmap=cmap, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('True')
    ax.set_xlabel('Predicted')


def plot_confusion_matrix(cm_train, cm_test, classes, cmap="YlGnBu",
                          figsize=(13, 5),
                          titles=("Training data", "Test data")):
    """Plot training and test confusion matrices side by side.

    A new figure is created on every call. The original hard-coded
    ``plt.figure(1)`` reactivated the same figure, which ignored
    ``figsize`` and drew over earlier plots on repeated calls.

    Args:
        cm_train: Confusion matrix for the training data (rows = true
            class, columns = predicted class).
        cm_test: Confusion matrix for the test data, same layout.
        classes: Class labels for both axes of each matrix.
        cmap: Colormap passed to ``sns.heatmap``; defaults to "YlGnBu".
        figsize: Figure size in inches; defaults to ``(13, 5)``.
        titles: Two-element sequence of panel titles (train, test).

    Returns:
        The newly created ``matplotlib.figure.Figure``. The figure is
        also the current pyplot figure, so ``plt.show()`` still works.
    """
    fig, (ax_train, ax_test) = plt.subplots(1, 2, figsize=figsize)
    _draw_confusion_heatmap(ax_train, cm_train, classes, titles[0], cmap)
    _draw_confusion_heatmap(ax_test, cm_test, classes, titles[1], cmap)
    return fig
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement