Untitled

a guest
Sep 21st, 2019
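
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def plot_confusion_matrix(cm_train, cm_test, classes, cmap="YlGnBu"):
    """Draw the training and test confusion matrices side by side as heatmaps."""
    # Create a fresh figure on every call. A fixed figure number (plt.figure(1))
    # makes repeated calls draw over the same axes and stack extra colorbars.
    fig = plt.figure(figsize=(13, 5))

    # Left panel: confusion matrix on the training data (rows = true class).
    plt.subplot(121)
    df_train = pd.DataFrame(cm_train, index=classes, columns=classes)
    sns.heatmap(df_train, annot=True, fmt='g', cmap=cmap)
    plt.title("Training data")
    plt.ylabel('True')
    plt.xlabel('Predicted')

    # Right panel: confusion matrix on the test data.
    plt.subplot(122)
    df_test = pd.DataFrame(cm_test, index=classes, columns=classes)
    sns.heatmap(df_test, annot=True, fmt='g', cmap=cmap)
    plt.title("Test data")
    plt.ylabel('True')
    plt.xlabel('Predicted')

    # Keep the two panels and their colorbars from overlapping.
    plt.tight_layout()
    return fig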
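The function expects `cm_train` and `cm_test` to be square count matrices whose row i is the true class `classes[i]` and whose column j is the predicted class `classes[j]`, matching the 'True' (y axis) and 'Predicted' (x axis) labels. The sketch below shows one way to build such matrices with NumPy alone; `confusion_counts` is a hypothetical helper and the labels are made up for illustration. If scikit-learn is available, `sklearn.metrics.confusion_matrix(y_true, y_pred, labels=classes)` uses the same row/column convention.

```python
import numpy as np


def confusion_counts(y_true, y_pred, classes):
    """Hypothetical helper: count (true, predicted) label pairs.

    Row i is the true class classes[i] and column j is the predicted class
    classes[j], so the row/column order matches the `classes` list passed
    to plot_confusion_matrix.
    """
    index = {label: i for i, label in enumerate(classes)}
    cm = np.zeros((len(classes), len(classes)), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[index[t], index[p]] += 1
    return cm


# Made-up labels, for illustration only.
classes = ["cat", "dog", "bird"]
y_train_true = ["cat", "dog", "dog", "bird", "cat"]
y_train_pred = ["cat", "dog", "cat", "bird", "cat"]
y_test_true = ["dog", "bird", "cat"]
y_test_pred = ["dog", "cat", "cat"]

cm_train = confusion_counts(y_train_true, y_train_pred, classes)
cm_test = confusion_counts(y_test_true, y_test_pred, classes)
```

With these inputs, calling `plot_confusion_matrix(cm_train, cm_test, classes)` and then `plt.show()` should draw both panels. Passing the same `classes` list to both the counting step and the plot keeps the heatmap tick labels aligned with the matrix rows and columns.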