Advertisement
Guest User

Untitled

a guest
Oct 15th, 2019
122
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.13 KB | None | 0 0
  # Confusion-matrix printing/plotting helper.
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """Print and plot the confusion matrix of ``y_true`` against ``y_pred``.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted labels, passed unchanged to
        ``confusion_matrix``.
    classes : sequence
        Tick labels used for both axes, applied as-is.
        NOTE(review): they must match, in order and count, the sorted label
        values that ``confusion_matrix`` uses for its rows/columns. Confirm
        against the caller.
    normalize : bool, default False
        If True, divide each row by its total, so a cell holds the fraction
        of that true class predicted as each class. A row with no samples
        stays all zeros instead of becoming NaN. That case arises when a
        label appears only in ``y_pred``.
    title : str or None
        Axes title. When falsy, a default that reflects ``normalize`` is used.
    cmap : matplotlib colormap, default ``plt.cm.Blues``
        Colormap passed to ``imshow``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes (on a newly created figure) the matrix was drawn on.
        The matrix is also printed to stdout.
    """
    if not title:
        title = ('Normalized confusion matrix' if normalize
                 else 'Confusion matrix, without normalization')

    # Compute the confusion matrix.
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Guard against 0/0 for empty rows. A NaN cell would also make
        # cm.max() NaN and break the text-colour threshold below.
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    n_rows, n_cols = cm.shape
    _, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # Show every tick and label it with the matching class name.
    ax.set(xticks=np.arange(n_cols),
           yticks=np.arange(n_rows),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')
    # Pin the y-limits to the outer cell edges. imshow already sets these,
    # so this is normally a no-op. matplotlib 3.1.1, however, crops half of
    # the first and last rows once ticks are set.
    ax.set_ylim(n_rows - 0.5, -0.5)

    # Rotate the x tick labels so long class names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate each cell with its value. Use white text on cells above half
    # the maximum (dark in the colormap) and black text elsewhere, for
    # contrast.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for (i, j), value in np.ndenumerate(cm):
        ax.text(j, i, format(value, fmt),
                ha="center", va="center",
                color="white" if value > thresh else "black")
    return ax
  54.  
  55.  
  56. from sklearn.metrics import f1_score
  57. from sklearn.metrics import balanced_accuracy_score
  58. from sklearn.metrics import confusion_matrix
  59.  
  # 1. Predict test-set labels with each of the knn_clf* classifiers.
Y_true = Y_test
(Y_pred_test,
 Y_pred_test_pitch,
 Y_pred_test_time,
 Y_pred_test_reverb,
 Y_pred_test_all) = (
    model.predict(X_test)
    for model in (knn_clf, knn_clf_pitch, knn_clf_time,
                  knn_clf_reverb, knn_clf_all)
)

# 2. Balanced accuracy of each prediction set against the same ground truth.
# NOTE(review): despite its name, `score_train` is computed on the test set
# from knn_clf's predictions. The name is kept because other cells may use it.
(score_train,
 score_pitch,
 score_time,
 score_reverb,
 score_all) = (
    balanced_accuracy_score(Y_true, predicted)
    for predicted in (Y_pred_test, Y_pred_test_pitch, Y_pred_test_time,
                      Y_pred_test_reverb, Y_pred_test_all)
)
  76.  
  # Confusion matrices.
# NOTE(review): `unique_labels` is not used in this section. It is kept in
# case other cells rely on the import.
from sklearn.utils.multiclass import unique_labels

# CONFUSION MATRIX FOR 1) the classifier trained on the original training
# data, evaluated on the test set.
Y_pred = Y_pred_test

# Raw (un-normalized) counts.
# NOTE(review): nothing below reads `cm`, since plot_confusion_matrix
# recomputes the matrix itself. It is kept in case other cells inspect it.
cm = confusion_matrix(Y_true, Y_pred)

# Print floating-point matrices with 2 decimal places.
np.set_printoptions(precision=2)
# Plot the row-normalized confusion matrix. Use the same ground truth
# (`Y_true`) that the scores and `cm` above were computed from.
plot_confusion_matrix(Y_true, Y_pred, classes=instruments, normalize=True,
                      title='Original Training Data')
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement