daily pastebin goal
28%
SHARE
TWEET

Untitled

a guest Dec 15th, 2017 145 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #!/usr/bin/python
  2.  
  3. # Author: SeyyedHossein Hasanpour copyright 2017, license GPLv3.
  4. # Seyyed Hossein Hasan Pour:
  5. # Coderx7@Gmail.com
  6. # Changelog:
  7. # 2015:
  8. # initial code to calculate confusionmatrix by Axel Angel
  9. # 7/3/2016:(adding new features-by-hossein)
  10. # added mean subtraction so that, the accuracy can be reported accurately just like caffe when doing a mean subtraction
  11. # 01/03/2017:
  12. # removed old codes and Added Recall/Precision/F1-Score as well
  13. # 03/05/2017
  14. # Added ConfusionMatrix which was mistakenly omitted before.
  15. #info:
  16. #if on windows, one can put these commands in a batch file for convenience:
  17. #REM Calculating Confusion Matrix
  18. #python confusionMatrix_convnet_test.py --proto cifar10_deploy_94_68.prototxt --model cifar10_deploy_94_68.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb  
  19. #
  20.  
  21.  
  22. import sys
  23. import caffe
  24. import numpy as np
  25. import lmdb
  26. import argparse
  27. from collections import defaultdict
  28. from sklearn.metrics import classification_report
  29. from sklearn.metrics import confusion_matrix
  30. import matplotlib.pyplot as plt
  31. import itertools
  32.  
  33. def flat_shape(x):
  34.     "Returns x without singleton dimension, eg: (1,28,28) -> (28,28)"
  35.     return x.reshape(filter(lambda s: s > 1, x.shape))
  36.  
  37. def db_reader(fpath, type='lmdb'):
  38.     if type == 'lmdb':
  39.         return lmdb_reader(fpath)
  40.     else:
  41.        return leveldb_reader(fpath)
  42.  
  43.  
  44. def plot_confusion_matrix(cm #confusion matrix
  45.                          ,classes
  46.                           ,normalize=False
  47.                           ,title='Confusion matrix'
  48.                           ,cmap=plt.cm.Blues):
  49.     """
  50.     This function prints and plots the confusion matrix.
  51.     Normalization can be applied by setting `normalize=True`.
  52.     """
  53.     plt.imshow(cm, interpolation='nearest', cmap=cmap)
  54.     plt.title(title)
  55.     plt.colorbar()
  56.     tick_marks = np.arange(len(classes))
  57.     plt.xticks(tick_marks, classes, rotation=45)
  58.     plt.yticks(tick_marks, classes)
  59.  
  60.     if normalize:
  61.         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  62.         print("confusion matrix is normalized!")
  63.    
  64.     #print(cm)
  65.  
  66.     thresh = cm.max() / 2.
  67.     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  68.         plt.text(j, i, cm[i, j],
  69.                  horizontalalignment="center",
  70.                  color="white" if cm[i, j] > thresh else "black")
  71.  
  72.     plt.tight_layout()
  73.     plt.ylabel('True label')
  74.     plt.xlabel('Predicted label')
  75.  
  76.  
  77.  
  78.        
  79. def lmdb_reader(fpath):
  80.     import lmdb
  81.     lmdb_env = lmdb.open(fpath)
  82.     lmdb_txn = lmdb_env.begin()
  83.     lmdb_cursor = lmdb_txn.cursor()
  84.  
  85.     for key, value in lmdb_cursor:
  86.         datum = caffe.proto.caffe_pb2.Datum()
  87.         datum.ParseFromString(value)
  88.         label = int(datum.label)
  89.         image = caffe.io.datum_to_array(datum).astype(np.uint8)
  90.         yield (key, flat_shape(image), label)
  91.  
  92. def leveldb_reader(fpath):
  93.     import leveldb
  94.     db = leveldb.LevelDB(fpath)
  95.  
  96.     for key, value in db.RangeIter():
  97.         datum = caffe.proto.caffe_pb2.Datum()
  98.         datum.ParseFromString(value)
  99.         label = int(datum.label)
  100.         image = caffe.io.datum_to_array(datum).astype(np.uint8)
  101.         yield (key, flat_shape(image), label)
  102.  
  103.  
  104. if __name__ == "__main__":
  105.     parser = argparse.ArgumentParser()
  106.     parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
  107.     parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
  108.     parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
  109.     #group = parser.add_mutually_exclusive_group(required=True)
  110.     parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
  111.     parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
  112.     args = parser.parse_args()
  113.  
  114.    # Extract mean from the mean image file
  115.     mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
  116.     f = open(args.mean, 'rb')
  117.     mean_blobproto_new.ParseFromString(f.read())
  118.     mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)
  119.     f.close()
  120.    
  121.    # CNN reconstruction and loading the trained weights
  122.     net = caffe.Net(args.proto, args.model, caffe.TEST)
  123.    # You may also use set_mode_cpu() if you didnt compile caffe with gpu support
  124.     caffe.set_mode_gpu()
  125.    
  126.     print ("args", vars(args))
  127.    
  128.     reader = db_reader(args.db_path, args.db_type.lower())
  129.    
  130.     predicted_lables=[]
  131.     true_labels = []
  132.     class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
  133.     for i, image, label in reader:
  134.         image_caffe = image.reshape(1, *image.shape)
  135.         #print 'image shape: ',image_caffe.shape
  136.         out = net.forward_all(data=np.asarray([ image_caffe ])- mean_image)
  137.         plabel = int(out['prob'][0].argmax(axis=0))
  138.         predicted_lables.append(plabel)
  139.         true_labels.append(label)
  140.         print(i,' processed!')
  141.    
  142.     print( classification_report(y_true=true_labels,
  143.                                  y_pred=predicted_lables,
  144.                                  target_names=class_names))
  145.     cm = confusion_matrix(y_true=true_labels,
  146.                             y_pred=predicted_lables)   
  147.     print(cm)                          
  148.     # Compute confusion matrix
  149.     cnf_matrix = cm
  150.     np.set_printoptions(precision=2)
  151.  
  152.     # Plot non-normalized confusion matrix
  153.     plt.figure()
  154.     plot_confusion_matrix(cnf_matrix, classes=class_names,
  155.                           title='Confusion matrix, without normalization')
  156.  
  157.     # Plot normalized confusion matrix
  158.     plt.figure()
  159.     plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
  160.                           title='Normalized confusion matrix')
  161.  
  162.     plt.show()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top