Advertisement
Guest User

Untitled

a guest
Sep 23rd, 2019
155
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.03 KB | None | 0 0
  1. def plot_confusion_matrix(y_true, y_pred, classes,
  2. normalize=False,
  3. title=None,
  4. cmap=plt.cm.Blues):
  5. """
  6. This function prints and plots the confusion matrix.
  7. Normalization can be applied by setting `normalize=True`.
  8. """
  9. if not title:
  10. if normalize:
  11. title = 'Normalized confusion matrix'
  12. else:
  13. title = 'Confusion matrix, without normalization'
  14.  
  15. # Compute confusion matrix
  16. cm = confusion_matrix(y_true, y_pred)
  17. if normalize:
  18. cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  19. print("Normalized confusion matrix")
  20. else:
  21. print('Confusion matrix, without normalization')
  22.  
  23. fig, ax = plt.subplots()
  24. fig.set_figheight(15)
  25. fig.set_figwidth(15)
  26. im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
  27. ax.figure.colorbar(im, ax=ax)
  28. # We want to show all ticks...
  29. ax.set(xticks=np.arange(cm.shape[1]),
  30. yticks=np.arange(cm.shape[0]),
  31. # ... and label them with the respective list entries
  32. xticklabels=classes, yticklabels=classes,
  33. title=title,
  34. ylabel='True label',
  35. xlabel='Predicted label')
  36.  
  37. # Rotate the tick labels and set their alignment.
  38. plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
  39. rotation_mode="anchor")
  40.  
  41. # Loop over data dimensions and create text annotations.
  42. fmt = '.2f' if normalize else 'd'
  43. thresh = cm.max() / 2.
  44. for i in range(cm.shape[0]):
  45. for j in range(cm.shape[1]):
  46. ax.text(j, i, format(cm[i, j], fmt),
  47. ha="center", va="center",
  48. color="white" if cm[i, j] > thresh else "black")
  49. fig.tight_layout()
  50. return ax
  51.  
  52.  
  53. np.set_printoptions(precision=2)
  54.  
  55.  
  56. # Plot non-normalized confusion matrix
  57. plot_confusion_matrix(y_true, y_pred, classes=target_names,
  58. title='Confusion matrix, without normalization')
  59.  
  60. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement