daily pastebin goal
55%
SHARE
TWEET

Untitled

a guest Dec 15th, 2017 144 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #!/usr/bin/python
  2.  
  3. # Author: SeyyedHossein Hasanpour copyright 2017, license GPLv3.
  4. # Seyyed Hossein Hasan Pour:
  5. # Coderx7@Gmail.com
  6. # Changelog:
  7. # 2015:
  8. # initial code to calculate confusionmatrix by Axel Angel
  9. # 7/3/2016:(adding new features-by-hossein)
  10. # added mean subtraction so that, the accuracy can be reported accurately just like caffe when doing a mean subtraction
  11. # 01/03/2017:
  12. # removed old codes and Added Recall/Precision/F1-Score as well
  13. # 03/05/2017
  14. # Added ConfusionMatrix which was mistakenly omitted before.
  15. #info:
  16. # On Windows, you can put the following commands in a batch file for convenience:
  17. #REM Calculating Confusion Matrix
  18. #python confusionMatrix_convnet_test.py --proto cifar10_deploy_94_68.prototxt --model cifar10_deploy_94_68.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb  
  19. #
  20.  
  21.  
  22. import sys
  23. import caffe
  24. import numpy as np
  25. import lmdb
  26. import argparse
  27. from collections import defaultdict
  28. from sklearn.metrics import classification_report
  29. from sklearn.metrics import confusion_matrix
  30. import matplotlib.pyplot as plt
  31. import itertools
  32.  
  33. def flat_shape(x):
  34.     "Returns x without singleton dimension, eg: (1,28,28) -> (28,28)"
  35.     return x.reshape(filter(lambda s: s > 1, x.shape))
  36.  
  37. def db_reader(fpath, type='lmdb'):
  38.     if type == 'lmdb':
  39.         return lmdb_reader(fpath)
  40.     else:
  41.        return leveldb_reader(fpath)
  42.  
  43.  
  44. def plot_confusion_matrix(cm #confusion matrix
  45.                          ,classes
  46.                           ,normalize=False
  47.                           ,title='Confusion matrix'
  48.                           ,cmap=plt.cm.Blues):
  49.     """
  50.     This function prints and plots the confusion matrix.
  51.     Normalization can be applied by setting `normalize=True`.
  52.     """
  53.     plt.imshow(cm, interpolation='nearest', cmap=cmap)
  54.     plt.title(title)
  55.     plt.colorbar()
  56.     tick_marks = np.arange(len(classes))
  57.     plt.xticks(tick_marks, classes, rotation=45)
  58.     plt.yticks(tick_marks, classes)
  59.  
  60.     if normalize:
  61.         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  62.         print("confusion matrix is normalized!")
  63.    
  64.     #print(cm)
  65.  
  66.     thresh = cm.max() / 2.
  67.     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  68.         plt.text(j, i, cm[i, j],
  69.                  horizontalalignment="center",
  70.                  color="white" if cm[i, j] > thresh else "black")
  71.  
  72.     plt.tight_layout()
  73.     plt.ylabel('True label')
  74.     plt.xlabel('Predicted label')
  75.  
  76.  
  77.  
  78.        
  79. def lmdb_reader(fpath):
  80.     import lmdb
  81.     lmdb_env = lmdb.open(fpath)
  82.     lmdb_txn = lmdb_env.begin()
  83.     lmdb_cursor = lmdb_txn.cursor()
  84.  
  85.     for key, value in lmdb_cursor:
  86.         datum = caffe.proto.caffe_pb2.Datum()
  87.         datum.ParseFromString(value)
  88.         label = int(datum.label)
  89.         image = caffe.io.datum_to_array(datum).astype(np.uint8)
  90.         yield (key, flat_shape(image), label)
  91.  
  92. def leveldb_reader(fpath):
  93.     import leveldb
  94.     db = leveldb.LevelDB(fpath)
  95.  
  96.     for key, value in db.RangeIter():
  97.         datum = caffe.proto.caffe_pb2.Datum()
  98.         datum.ParseFromString(value)
  99.         label = int(datum.label)
  100.         image = caffe.io.datum_to_array(datum).astype(np.uint8)
  101.         yield (key, flat_shape(image), label)
  102.  
  103.  
  104. if __name__ == "__main__":
  105.     parser = argparse.ArgumentParser()
  106.     parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
  107.     parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
  108.     parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
  109.     #group = parser.add_mutually_exclusive_group(required=True)
  110.     parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
  111.     parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
  112.     args = parser.parse_args()
  113.  
  114.    # Extract mean from the mean image file
  115.     mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
  116.     f = open(args.mean, 'rb')
  117.     mean_blobproto_new.ParseFromString(f.read())
  118.     mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)
  119.     f.close()
  120.    
  121.    # CNN reconstruction and loading the trained weights
  122.     net = caffe.Net(args.proto, args.model, caffe.TEST)
  123.    # You may also use set_mode_cpu() if you didnt compile caffe with gpu support
  124.     caffe.set_mode_gpu()
  125.    
  126.     print ("args", vars(args))
  127.    
  128.     reader = db_reader(args.db_path, args.db_type.lower())
  129.    
  130.     predicted_lables=[]
  131.     true_labels = []
  132.     class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
  133.     for i, image, label in reader:
  134.         image_caffe = image.reshape(1, *image.shape)
  135.         #print 'image shape: ',image_caffe.shape
  136.         out = net.forward_all(data=np.asarray([ image_caffe ])- mean_image)
  137.         plabel = int(out['prob'][0].argmax(axis=0))
  138.         predicted_lables.append(plabel)
  139.         true_labels.append(label)
  140.         print(i,' processed!')
  141.    
  142.     print( classification_report(y_true=true_labels,
  143.                                  y_pred=predicted_lables,
  144.                                  target_names=class_names))
  145.     cm = confusion_matrix(y_true=true_labels,
  146.                             y_pred=predicted_lables)   
  147.     print(cm)                          
  148.     # Compute confusion matrix
  149.     cnf_matrix = cm
  150.     np.set_printoptions(precision=2)
  151.  
  152.     # Plot non-normalized confusion matrix
  153.     plt.figure()
  154.     plot_confusion_matrix(cnf_matrix, classes=class_names,
  155.                           title='Confusion matrix, without normalization')
  156.  
  157.     # Plot normalized confusion matrix
  158.     plt.figure()
  159.     plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
  160.                           title='Normalized confusion matrix')
  161.  
  162.     plt.show()
RAW Paste Data
Top