Advertisement
Guest User

Untitled

a guest
Jul 26th, 2017
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 12.71 KB | None | 0 0
# In the name of God, the most compassionate, the most merciful
# Note: remember to change the crop size in the code! It is hardcoded (224x224); there is no command-line argument for it, so edit it yourself.
# Mean subtraction was added so that the reported accuracy matches what Caffe reports when it performs mean subtraction.
# Seyyed Hossein Hasan Pour
# Coderx7@Gmail.com
# 7/3/2016
# Added Recall/Precision/F1-Score as well
# 01/03/2017
# Info:
# On Windows, you can put these commands in a batch file to make life easier:
# REM Calculating Confusion Matrix
# python confusionMatrix_convnet_test.py --proto deploy.prototxt --model simpleNet.caffemodel --mean mean_imagenet.binaryproto --db_type lmdb --db_path mydata_test_lmdb
# pause
  14.  
  15.  
  16. import sys
  17. import caffe
  18. import numpy as np
  19. import lmdb
  20. import argparse
  21. from collections import defaultdict
  22. from sklearn.metrics import classification_report
  23. from sklearn.metrics import confusion_matrix
  24. import matplotlib.pyplot as plt
  25. import itertools
  26. from sklearn.metrics import roc_curve, auc
  27. import random
  28.  
  29. def flat_shape(x):
  30. "Returns x without singleton dimension, eg: (1,28,28) -> (28,28)"
  31. return np.reshape(x,x.shape)
  32.  
  33. def plot_confusion_matrix(cm #confusion matrix
  34. ,classes
  35. ,normalize=False
  36. ,title='Confusion matrix'
  37. ,cmap=plt.cm.Blues):
  38. """
  39. This function prints and plots the confusion matrix.
  40. Normalization can be applied by setting `normalize=True`.
  41. """
  42. plt.imshow(cm, interpolation='nearest', cmap=cmap)
  43. plt.title(title)
  44. plt.colorbar()
  45. tick_marks = np.arange(len(classes))
  46. plt.xticks(tick_marks, classes, rotation=45)
  47. plt.yticks(tick_marks, classes)
  48.  
  49. if normalize:
  50. cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  51. print("confusion matrix is normalized!")
  52.  
  53. #print(cm)
  54.  
  55. thresh = cm.max() / 2.
  56. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  57. plt.text(j, i, cm[i, j],
  58. horizontalalignment="center",
  59. color="white" if cm[i, j] > thresh else "black")
  60.  
  61. plt.tight_layout()
  62. plt.ylabel('True label')
  63. plt.xlabel('Predicted label')
  64.  
  65.  
  66. def db_reader(fpath, type='lmdb'):
  67. if type == 'lmdb':
  68. return lmdb_reader(fpath)
  69. else:
  70. return leveldb_reader(fpath)
  71.  
  72.  
  73. def lmdb_reader(fpath):
  74. import lmdb
  75. lmdb_env = lmdb.open(fpath)
  76. lmdb_txn = lmdb_env.begin()
  77. lmdb_cursor = lmdb_txn.cursor()
  78.  
  79. for key, value in lmdb_cursor:
  80. datum = caffe.proto.caffe_pb2.Datum()
  81. datum.ParseFromString(value)
  82. label = int(datum.label)
  83. image = caffe.io.datum_to_array(datum).astype(np.uint8)
  84. yield (key, flat_shape(image), label)
  85.  
  86. def leveldb_reader(fpath):
  87. import leveldb
  88. db = leveldb.LevelDB(fpath)
  89.  
  90. for key, value in db.RangeIter():
  91. datum = caffe.proto.caffe_pb2.Datum()
  92. datum.ParseFromString(value)
  93. label = int(datum.label)
  94. image = caffe.io.datum_to_array(datum).astype(np.uint8)
  95. yield (key, flat_shape(image), label)
  96.  
  97.  
  98. def ShowInfo(correct, count, true_labels, predicted_lables, class_names, misclassified,
  99. filename='misclassifieds.txt',
  100. title='Receiver Operating Characteristic_ROC',
  101. title_CM='Confusion matrix, without normalization',
  102. title_CM_N='Normalized confusion matrix'):
  103. sys.stdout.write("\rAccuracy: %.2f%%" % (100.*correct/count))
  104. sys.stdout.flush()
  105. print(", %i/%i corrects" % (correct, count))
  106.  
  107. np.savetxt(filename,misclassified,fmt="%s")
  108. print( classification_report(y_true=true_labels,
  109. y_pred=predicted_lables,
  110. target_names=class_names))
  111.  
  112. cm = confusion_matrix(y_true=true_labels,
  113. y_pred=predicted_lables)
  114. print(cm)
  115.  
  116. print(title)
  117. false_positive_rate, true_positive_rate, thresholds = roc_curve(true_labels, predicted_lables)
  118. roc_auc = auc(false_positive_rate, true_positive_rate)
  119. plt.title('Receiver Operating Characteristic_ROC 1')
  120. plt.plot(false_positive_rate, true_positive_rate, 'b',
  121. label='AUC = %0.2f'% roc_auc)
  122. plt.legend(loc='lower right')
  123. plt.plot([0,1],[0,1],'r--')
  124. plt.xlim([-0.1,1.2])
  125. plt.ylim([-0.1,1.2])
  126. plt.ylabel('True Positive Rate')
  127. plt.xlabel('False Positive Rate')
  128. plt.show()
  129.  
  130. # Compute confusion matrix
  131. cnf_matrix = cm
  132. np.set_printoptions(precision=2)
  133. # Plot non-normalized confusion matrix
  134. plt.figure()
  135. plot_confusion_matrix(cnf_matrix, classes=class_names,
  136. title=title_CM)
  137. # Plot normalized confusion matrix
  138. plt.figure()
  139. plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
  140. title=title_CM_N)
  141. plt.show()
  142.  
  143.  
  144.  
  145. if __name__ == "__main__":
  146. parser = argparse.ArgumentParser()
  147. parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
  148. parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
  149. parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
  150. #group = parser.add_mutually_exclusive_group(required=True)
  151. parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
  152. parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
  153. args = parser.parse_args()
  154.  
  155. # Extract mean from the mean image file
  156. mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
  157. f = open(args.mean, 'rb')
  158. mean_blobproto_new.ParseFromString(f.read())
  159. mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)
  160. f.close()
  161.  
  162.  
  163. caffe.set_mode_gpu()
  164. #CNN reconstruction and loading the trained weights
  165. #print ("args", vars(args))
  166.  
  167.  
  168.  
  169. predicted_lables=[]
  170. true_labels = []
  171. misclassified =[]
  172. class_names = ['unsafe','safe']
  173. count=0
  174. correct = 0
  175. idx=0
  176. batch=[]
  177. plabe_ls=[]
  178. batch_size = 50
  179. cropx = 224
  180. cropy = 224
  181. net1 = caffe.Net(args.proto, args.model, caffe.TEST)
  182. # transformer = caffe.io.Transformer({'data': net1.blobs['data'].data.shape})
  183. # transformer.set_transpose('data', (2,0,1))
  184. # transformer.set_mean('data', mean_image[0].mean(1).mean(1))
  185. # transformer.set_raw_scale('data', 255)
  186. # transformer.set_channel_swap('data', (2,1,0))
  187. net1.blobs['data'].reshape(batch_size, 3,224, 224)
  188. data_blob_shape = net1.blobs['data'].data.shape
  189. data_blob_shape = list(data_blob_shape)
  190. #net1.blobs['data'].reshape(batch_size, data_blob_shape[1], data_blob_shape[2], data_blob_shape[3])
  191. i=0
  192.  
  193. #mu = np.load('mean.npy')
  194. mu = np.array([ 104, 117, 123])#imagenet mean
  195.  
  196. #reader = db_reader(args.db_path, args.db_type.lower())
  197. #check and see if its lmdb or leveldb
  198. if(args.db_type.lower() == 'lmdb'):
  199. lmdb_env = lmdb.open(args.db_path)
  200. lmdb_txn = lmdb_env.begin()
  201. lmdb_cursor = lmdb_txn.cursor()
  202. for key, value in lmdb_cursor:
  203. count += 1
  204. datum = caffe.proto.caffe_pb2.Datum()
  205. datum.ParseFromString(value)
  206. label = int(datum.label)
  207. image = caffe.io.datum_to_array(datum).astype(np.uint8)
  208. #key,image,label
  209. #buffer n image
  210. if(count%5000==0):
  211. print('{0} samples processed so far'.format(count))
  212. if(i < batch_size):
  213. i+=1
  214. inf= key,image,label
  215. batch.append(inf)
  216. #print(key)
  217. if(i >= batch_size):
  218. #process n image
  219. ims=[]
  220. for x in range(len(batch)):
  221. img = batch[x][1]
  222. #img has c,h,w shape! its already gone through transpose and channel swap when it was being saved into lmdb!
  223. #method I: crop the both the image and mean file
  224. #ims.append(img[:,0:224,0:224] - mean_image[0][:,0:224,0:224] )
  225. #Method II : resize the image to the desired size(crop size)
  226. #img = caffe.io.resize_image(img.transpose(2,1,0), (224, 224))
  227. #Method III : use center crop just like caffe does in test time
  228. #center crop
  229. c,h,w = img.shape
  230. startx = h//2 - cropx//2
  231. starty = w//2 - cropy//2
  232. img = img[:, startx:startx + cropx, starty:starty + cropy]
  233. #transpose the image so we can subtract from mean, and in the meanwhile, change it to float!
  234. img = np.array(img.transpose(2,1,0),dtype=np.float32)
  235. img -= mean_image[0].mean(1).mean(1)
  236. #transpose back to the original state
  237. img = img.transpose(2,1,0)
  238. ims.append(img)
  239.  
  240. net1.blobs['data'].data[...] = ims[:]
  241. out_1 = net1.forward()
  242. #print('batch processed')
  243. plabe_ls = out_1['pred']#.argmax(axis=0)
  244. plbl = np.asarray(plabe_ls)
  245. #print(plbl)
  246. #print(plbl.argmax(axis=1))
  247. plbl = plbl.argmax(axis=1)
  248. for j in range(len(batch)):
  249. if (plbl[j] == batch[j][2]):
  250. correct+=1
  251. else:
  252. misclassified.append(batch[j][0])
  253.  
  254. predicted_lables.append(plbl[j])
  255. true_labels.append(batch[j][2])
  256. batch.clear()
  257. i=0
  258.  
  259.  
  260. ShowInfo(correct,count, true_labels, predicted_lables, class_names, misclassified,
  261. filename='misclassifieds.txt',
  262. title='Receiver Operating Characteristic_ROC' )
  263.  
  264. else:#leveldb
  265. import leveldb
  266. db = leveldb.LevelDB(args.db_path)
  267. for key, value in db.RangeIter():
  268. datum = caffe.proto.caffe_pb2.Datum()
  269. datum.ParseFromString(value)
  270. label = int(datum.label)
  271. image = caffe.io.datum_to_array(datum).astype(np.uint8)
  272. #key,image,label
  273. #buffer n image
  274. #print('count: ',count)
  275. if(i < 50):
  276. i+=1
  277. inf= key,image,label
  278. batch.append(inf)
  279. #print(key)
  280. if(i >= 50):
  281. #process n image
  282. ims=[]
  283. for x in range(len(batch)):
  284. img = batch[x][1]
  285. #img has c,h,w shape! its already gone through transpose and channel swap when it was being saved into lmdb!
  286. #method I: crop the both the image and mean file
  287. #ims.append(img[:,0:224,0:224] - mean_image[0][:,0:224,0:224] )
  288. #Method II : resize the image to the desired size(crop size)
  289. #img = caffe.io.resize_image(img.transpose(2,1,0), (224, 224))
  290. #Method III : use center crop just like caffe does in test time
  291. #center crop
  292. c,h,w = img.shape
  293. startx = h//2 - cropx//2
  294. starty = w//2 - cropy//2
  295. img = img[:, startx:startx + cropx, starty:starty + cropy]
  296. #transpose the image so we can subtract from mean, and in the meanwhile, change it to float!
  297. img = np.array(img.transpose(2,1,0),dtype=np.float32)
  298. img -= mean_image[0].mean(1).mean(1)
  299. #transpose back to the original state
  300. img = img.transpose(2,1,0)
  301. ims.append(img)
  302.  
  303. net1.blobs['data'].data[...] = ims[:]
  304. out_1 = net1.forward()
  305. plabe_ls = out_1['pred']#.argmax(axis=0)
  306. plbl = np.asarray(plabe_ls)
  307. #print(plbl)
  308. #print(plbl.argmax(axis=1))
  309. plbl = plbl.argmax(axis=1)
  310. for j in range(len(batch)):
  311. if (plbl[j] == batch[j][2]):
  312. correct+=1
  313. else:
  314. misclassified.append(batch[j][0])
  315.  
  316. predicted_lables.append(plbl[j])
  317. true_labels.append(batch[j][2])
  318. batch.clear()
  319. i=0
  320.  
  321. ShowInfo(correct,count, true_labels, predicted_lables, class_names, misclassified,
  322. filename='misclassifieds.txt',
  323. title='Receiver Operating Characteristic_ROC' )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement