• API
• FAQ
• Tools
• Archive
A Pastebin account makes a great Christmas gift
SHARE
TWEET

# Untitled

a guest Dec 15th, 2017 145 Never
Ending in 00 days 00 hours 00 mins 00 secs

1. #!/usr/bin/python
2.
4. # Seyyed Hossein Hasan Pour:
5. # Coderx7@Gmail.com
6. # Changelog:
7. # 2015:
8. # Initial code to calculate the confusion matrix, by Axel Angel
10. # Added mean subtraction so that the reported accuracy matches what Caffe reports when it performs mean subtraction
11. # 01/03/2017:
12. # Removed old code and added Recall/Precision/F1-Score as well
13. # 03/05/2017
14. # Added the confusion matrix, which was mistakenly omitted before.
15. #info:
16. #If you are on Windows, you can put these commands in a batch file for convenience:
17. #REM Calculating Confusion Matrix
18. #python confusionMatrix_convnet_test.py --proto cifar10_deploy_94_68.prototxt --model cifar10_deploy_94_68.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb
19. #
20.
21.
22. import sys
23. import caffe
24. import numpy as np
25. import lmdb
26. import argparse
27. from collections import defaultdict
28. from sklearn.metrics import classification_report
29. from sklearn.metrics import confusion_matrix
30. import matplotlib.pyplot as plt
31. import itertools
32.
def flat_shape(x):
    """Return *x* without its singleton dimensions, e.g. (1, 28, 28) -> (28, 28).

    Args:
        x: a NumPy array (anything exposing ``.shape`` and ``.reshape``).

    Returns:
        ``x`` reshaped with every size-1 axis removed (a view where NumPy can
        provide one). An array whose axes are all size 1 becomes 0-d.
    """
    # Build the new shape as a concrete tuple: under Python 3 ``filter``
    # returns a lazy iterator, which ``ndarray.reshape`` does not accept.
    # Drop only size-1 axes (``!= 1`` rather than ``> 1``) so zero-length
    # axes are kept and empty arrays still reshape correctly.
    return x.reshape(tuple(s for s in x.shape if s != 1))
36.
38.     if type == 'lmdb':
40.     else:
42.
43.
def plot_confusion_matrix(cm,  # confusion matrix
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Draw a confusion matrix on the current matplotlib figure.

    Rows are true labels and columns are predicted labels. Each cell is
    annotated with its value. The caller is responsible for creating the
    figure and calling ``plt.show()``.

    Args:
        cm: square confusion matrix (array-like), e.g. the output of
            ``sklearn.metrics.confusion_matrix``.
        classes: tick labels for the rows/columns, one per class.
        normalize: if True, scale each row to sum to 1 before drawing and
            annotate cells with two decimals. Rows that sum to zero are
            drawn as zeros.
        title: figure title.
        cmap: matplotlib colormap used for the cells.
    """
    cm = np.asarray(cm)

    if normalize:
        # Normalize *before* drawing so the image and colour bar reflect the
        # normalized values, not just the text annotations.
        row_sums = cm.sum(axis=1, keepdims=True)
        # A class with no true samples yields an all-zero row; leave it at 0
        # instead of dividing by zero and drawing NaNs.
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_sums != 0)
        print("confusion matrix is normalized!")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Switch the annotation colour at half the maximum value so the text
    # contrasts with the cell (with the default 'Blues' map, high = dark).
    thresh = cm.max() / 2. if cm.size else 0.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        value = cm[i, j]
        plt.text(j, i, format(value, '.2f') if normalize else value,
                 horizontalalignment="center",
                 color="white" if value > thresh else "black")

    # Set the axis labels before tight_layout() so they are included in the
    # layout computation and not clipped.
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
75.
76.
77.
78.
80.     import lmdb
81.     lmdb_env = lmdb.open(fpath)
82.     lmdb_txn = lmdb_env.begin()
83.     lmdb_cursor = lmdb_txn.cursor()
84.
85.     for key, value in lmdb_cursor:
86.         datum = caffe.proto.caffe_pb2.Datum()
87.         datum.ParseFromString(value)
88.         label = int(datum.label)
89.         image = caffe.io.datum_to_array(datum).astype(np.uint8)
90.         yield (key, flat_shape(image), label)
91.
93.     import leveldb
94.     db = leveldb.LevelDB(fpath)
95.
96.     for key, value in db.RangeIter():
97.         datum = caffe.proto.caffe_pb2.Datum()
98.         datum.ParseFromString(value)
99.         label = int(datum.label)
100.         image = caffe.io.datum_to_array(datum).astype(np.uint8)
101.         yield (key, flat_shape(image), label)
102.
103.
104. if __name__ == "__main__":
105.     parser = argparse.ArgumentParser()
106.     parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
108.     parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
110.     parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
112.     args = parser.parse_args()
113.
114.    # Extract mean from the mean image file
115.     mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
116.     f = open(args.mean, 'rb')
118.     mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)
119.     f.close()
120.
122.     net = caffe.Net(args.proto, args.model, caffe.TEST)
123.    # You may also use set_mode_cpu() if you didnt compile caffe with gpu support
124.     caffe.set_mode_gpu()
125.
126.     print ("args", vars(args))
127.
129.
130.     predicted_lables=[]
131.     true_labels = []
132.     class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
133.     for i, image, label in reader:
134.         image_caffe = image.reshape(1, *image.shape)
135.         #print 'image shape: ',image_caffe.shape
136.         out = net.forward_all(data=np.asarray([ image_caffe ])- mean_image)
137.         plabel = int(out['prob'][0].argmax(axis=0))
138.         predicted_lables.append(plabel)
139.         true_labels.append(label)
140.         print(i,' processed!')
141.
142.     print( classification_report(y_true=true_labels,
143.                                  y_pred=predicted_lables,
144.                                  target_names=class_names))
145.     cm = confusion_matrix(y_true=true_labels,
146.                             y_pred=predicted_lables)
147.     print(cm)
148.     # Compute confusion matrix
149.     cnf_matrix = cm
150.     np.set_printoptions(precision=2)
151.
152.     # Plot non-normalized confusion matrix
153.     plt.figure()
154.     plot_confusion_matrix(cnf_matrix, classes=class_names,
155.                           title='Confusion matrix, without normalization')
156.
157.     # Plot normalized confusion matrix
158.     plt.figure()
159.     plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
160.                           title='Normalized confusion matrix')
161.
162.     plt.show()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top