Guest User

Untitled

a guest
Dec 15th, 2017
185
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.56 KB | None | 0 0
  1. #!/usr/bin/python
  2.  
  3. # Author: SeyyedHossein Hasanpour copyright 2017, license GPLv3.
  4. # Seyyed Hossein Hasan Pour:
  5. # Coderx7@Gmail.com
  6. # Changelog:
  7. # 2015:
  8. # initial code to calculate the confusion matrix by Axel Angel
  9. # 7/3/2016:(adding new features-by-hossein)
  10. # added mean subtraction so that the accuracy is reported the same way Caffe reports it when mean subtraction is used
  11. # 01/03/2017:
  12. # removed old code and added Recall/Precision/F1-Score as well
  13. # 03/05/2017
  14. # Added ConfusionMatrix, which was mistakenly omitted before.
  15. #info:
  16. #if on Windows, you can put these commands in a batch file for convenience
  17. #REM Calculating Confusion Matrix
  18. #python confusionMatrix_convnet_test.py --proto cifar10_deploy_94_68.prototxt --model cifar10_deploy_94_68.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb
  19. #
  20.  
  21.  
  22. import sys
  23. import caffe
  24. import numpy as np
  25. import lmdb
  26. import argparse
  27. from collections import defaultdict
  28. from sklearn.metrics import classification_report
  29. from sklearn.metrics import confusion_matrix
  30. import matplotlib.pyplot as plt
  31. import itertools
  32.  
  33. def flat_shape(x):
  34. "Returns x without singleton dimension, eg: (1,28,28) -> (28,28)"
  35. return x.reshape(filter(lambda s: s > 1, x.shape))
  36.  
  37. def db_reader(fpath, type='lmdb'):
  38. if type == 'lmdb':
  39. return lmdb_reader(fpath)
  40. else:
  41. return leveldb_reader(fpath)
  42.  
  43.  
  44. def plot_confusion_matrix(cm #confusion matrix
  45. ,classes
  46. ,normalize=False
  47. ,title='Confusion matrix'
  48. ,cmap=plt.cm.Blues):
  49. """
  50. This function prints and plots the confusion matrix.
  51. Normalization can be applied by setting `normalize=True`.
  52. """
  53. plt.imshow(cm, interpolation='nearest', cmap=cmap)
  54. plt.title(title)
  55. plt.colorbar()
  56. tick_marks = np.arange(len(classes))
  57. plt.xticks(tick_marks, classes, rotation=45)
  58. plt.yticks(tick_marks, classes)
  59.  
  60. if normalize:
  61. cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  62. print("confusion matrix is normalized!")
  63.  
  64. #print(cm)
  65.  
  66. thresh = cm.max() / 2.
  67. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  68. plt.text(j, i, cm[i, j],
  69. horizontalalignment="center",
  70. color="white" if cm[i, j] > thresh else "black")
  71.  
  72. plt.tight_layout()
  73. plt.ylabel('True label')
  74. plt.xlabel('Predicted label')
  75.  
  76.  
  77.  
  78.  
  79. def lmdb_reader(fpath):
  80. import lmdb
  81. lmdb_env = lmdb.open(fpath)
  82. lmdb_txn = lmdb_env.begin()
  83. lmdb_cursor = lmdb_txn.cursor()
  84.  
  85. for key, value in lmdb_cursor:
  86. datum = caffe.proto.caffe_pb2.Datum()
  87. datum.ParseFromString(value)
  88. label = int(datum.label)
  89. image = caffe.io.datum_to_array(datum).astype(np.uint8)
  90. yield (key, flat_shape(image), label)
  91.  
  92. def leveldb_reader(fpath):
  93. import leveldb
  94. db = leveldb.LevelDB(fpath)
  95.  
  96. for key, value in db.RangeIter():
  97. datum = caffe.proto.caffe_pb2.Datum()
  98. datum.ParseFromString(value)
  99. label = int(datum.label)
  100. image = caffe.io.datum_to_array(datum).astype(np.uint8)
  101. yield (key, flat_shape(image), label)
  102.  
  103.  
  104. if __name__ == "__main__":
  105. parser = argparse.ArgumentParser()
  106. parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
  107. parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
  108. parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
  109. #group = parser.add_mutually_exclusive_group(required=True)
  110. parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
  111. parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
  112. args = parser.parse_args()
  113.  
  114. # Extract mean from the mean image file
  115. mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
  116. f = open(args.mean, 'rb')
  117. mean_blobproto_new.ParseFromString(f.read())
  118. mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)
  119. f.close()
  120.  
  121. # CNN reconstruction and loading the trained weights
  122. net = caffe.Net(args.proto, args.model, caffe.TEST)
  123. # You may also use set_mode_cpu() if you didnt compile caffe with gpu support
  124. caffe.set_mode_gpu()
  125.  
  126. print ("args", vars(args))
  127.  
  128. reader = db_reader(args.db_path, args.db_type.lower())
  129.  
  130. predicted_lables=[]
  131. true_labels = []
  132. class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
  133. for i, image, label in reader:
  134. image_caffe = image.reshape(1, *image.shape)
  135. #print 'image shape: ',image_caffe.shape
  136. out = net.forward_all(data=np.asarray([ image_caffe ])- mean_image)
  137. plabel = int(out['prob'][0].argmax(axis=0))
  138. predicted_lables.append(plabel)
  139. true_labels.append(label)
  140. print(i,' processed!')
  141.  
  142. print( classification_report(y_true=true_labels,
  143. y_pred=predicted_lables,
  144. target_names=class_names))
  145. cm = confusion_matrix(y_true=true_labels,
  146. y_pred=predicted_lables)
  147. print(cm)
  148. # Compute confusion matrix
  149. cnf_matrix = cm
  150. np.set_printoptions(precision=2)
  151.  
  152. # Plot non-normalized confusion matrix
  153. plt.figure()
  154. plot_confusion_matrix(cnf_matrix, classes=class_names,
  155. title='Confusion matrix, without normalization')
  156.  
  157. # Plot normalized confusion matrix
  158. plt.figure()
  159. plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
  160. title='Normalized confusion matrix')
  161.  
  162. plt.show()
Add Comment
Please, Sign In to add comment