Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix
# Global NumPy display settings: 2 decimal places, fixed-point output (suppress=True).
np.set_printoptions(precision=2, suppress=True)
def get_confusion_matrix(dist, network, mode, data_split, normalize=True):
    """Load saved predictions and build an 8-class confusion matrix grouped DNO | DUC.

    Reads ``8_classes_training/{dist}/{network}/csv_results/{mode}_{network}_{data_split}.csv``
    (relative to the current working directory). The CSV must provide ``true``
    and ``pred`` columns, which are tabulated with
    :func:`sklearn.metrics.confusion_matrix` (rows = true class, columns =
    predicted class).

    Args:
        dist, network, mode, data_split: Path components selecting the CSV.
        normalize: If True, divide each row by its total so that each row holds
            per-true-class rates summing to 1. A row whose class never occurs
            in ``true`` has a zero total and becomes NaN (0 / 0).

    Returns:
        pandas.DataFrame whose index and columns are the class ids 1..8,
        reordered so the DNO classes (2, 6, 8) come first, followed by the DUC
        classes (1, 3, 4, 5, 7).

    Raises:
        ValueError: If ``true`` and ``pred`` together do not contain exactly 8
            distinct labels, e.g. a class is entirely absent from this split.
    """
    n_classes = 8
    groups = {'DNO': [2, 6, 8], 'DUC': [1, 3, 4, 5, 7]}
    order = groups['DNO'] + groups['DUC']

    csv_path = (
        f'8_classes_training/{dist}/{network}/csv_results/'
        f'{mode}_{network}_{data_split}.csv'
    )
    results = pd.read_csv(csv_path)

    # Without `labels=`, sklearn sizes the matrix by the sorted union of labels
    # seen in true/pred, so a missing class silently shrinks it. Fail loudly
    # instead of letting the DataFrame constructor raise an opaque shape error.
    cm = confusion_matrix(results.true, results.pred)
    if cm.shape != (n_classes, n_classes):
        raise ValueError(
            f'expected a {n_classes}x{n_classes} confusion matrix for '
            f'{csv_path!r}, got {cm.shape[0]}x{cm.shape[1]}; every class must '
            f'appear in the `true` or `pred` column'
        )

    if normalize:
        cm = cm.astype(float) / cm.sum(axis=1, keepdims=True)

    # The k-th sorted label is renamed to class k + 1.
    # NOTE(review): the CSV labels may be 1..8 already or 0-based 0..7;
    # both map correctly only under this positional renaming — confirm.
    class_ids = range(1, n_classes + 1)
    cm = pd.DataFrame(cm, index=class_ids, columns=class_ids)

    # Reorder rows and columns together so the DNO block sits top-left.
    return cm.loc[order, order]
def plot_confusion_matrix(cm, figsize=(5, 5), c1=1.5, c2=5.5, split=3):
    """Draw a grouped confusion matrix as an annotated heatmap and show it.

    Args:
        cm: Square confusion matrix, such as the DataFrame returned by
            ``get_confusion_matrix``, with rows = true class and
            columns = predicted class. Cells are annotated as percentages
            (``fmt='.1%'``), so a row-normalized matrix is expected.
        figsize: Figure size passed to ``plt.figure``.
        c1, c2: Centre positions, in cell units, of the "DNO" and "DUC" group
            captions along each axis. The defaults match a 3 + 5 class split.
        split: Row/column index where the DNO block ends and the DUC block
            begins; a thick separator line is drawn there on both axes.
    """
    plt.figure(figsize=figsize)
    ax = plt.subplot()
    sns.heatmap(cm, annot=True, cmap=plt.cm.Blues,
                ax=ax, cbar=False, fmt='.1%')

    # The heatmap draws the matrix rows (true classes, per sklearn's
    # convention) along the y-axis and the columns (predicted classes)
    # along the x-axis.
    ax.set_ylabel('Actual')
    ax.set_xlabel('Prediction')

    # Thick lines separating the DNO and DUC groups.
    ax.axvline(split, color="black", lw=4)
    ax.axhline(split, color="black", lw=4)

    # Group captions for the columns (at y just outside row 0) and, rotated,
    # for the rows (at x just past the last column).
    n = len(cm)
    ax.text(c1, -0.1, "DNO", horizontalalignment='center')
    ax.text(c2, -0.1, "DUC", horizontalalignment='center')
    ax.text(n + .2, c1 - .1, "DNO", horizontalalignment='center', rotation="vertical")
    ax.text(n + .2, c2 - .1, "DUC", horizontalalignment='center', rotation="vertical")
    plt.show()
# Build the ResNet-50 / mode 'A' / test-split matrices for the three
# distributions first, then plot each one in the same order.
cm1, cm2, cm3 = (
    get_confusion_matrix(dist, 'resnet50', 'A', 'test')
    for dist in ('dist_1', 'dist_2', 'dist_3')
)
for matrix in (cm1, cm2, cm3):
    plot_confusion_matrix(matrix)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement