Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #in the name of God the most compassionate the most merciful
#NOTE: remember to change the crop size in the code — it is hardcoded (224x224); no command-line argument was added for it.
- # added mean subtraction so that, the accuracy can be reported accurately just like caffe when doing a mean subtraction
- # Seyyed Hossein Hasan Pour
- # Coderx7@Gmail.com
- # 7/3/2016
- # Added Recall/Precision/F1-Score as well
- # 01/03/2017
- #info:
- #if on windows, one can use these command in a batch file and ease him/her self
#REM Calculating Confusion Matrix
- #python confusionMatrix_convnet_test.py --proto deploy.prototxt --model simpleNet.caffemodel. --mean mean_imagenet.binaryproto --lmdb mydata_test_lmdb
- #pause
- import sys
- import caffe
- import numpy as np
- import lmdb
- import argparse
- from collections import defaultdict
- from sklearn.metrics import classification_report
- from sklearn.metrics import confusion_matrix
- import matplotlib.pyplot as plt
- import itertools
- from sklearn.metrics import roc_curve, auc
- import random
def flat_shape(x):
    """Return *x* with singleton dimensions removed, e.g. (1, 28, 28) -> (28, 28).

    Bug fix: the previous `np.reshape(x, x.shape)` was a no-op that kept the
    singleton axes; `np.squeeze` actually drops them as the docstring promises.
    """
    return np.squeeze(x)
def plot_confusion_matrix(cm,  # confusion matrix (2-D array of counts)
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print and plot the confusion matrix.
    Normalization can be applied by setting `normalize=True`.

    Bug fix: normalization is now applied BEFORE `imshow`, so the rendered
    image, the colorbar and the per-cell text all show the same values.
    Previously the image was drawn from the raw counts while the cell text
    showed the normalized ratios.
    """
    if normalize:
        # Row-normalize: each row (true class) sums to 1.
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("confusion matrix is normalized!")
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # Annotate every cell; white text on dark cells for contrast.
    thresh = cm.max() / 2.
    fmt = '.2f' if normalize else 'd'
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def db_reader(fpath, type='lmdb'):
    """Return a (key, image, label) generator for the database at *fpath*.

    *type* selects the backend: 'lmdb' dispatches to lmdb_reader; any other
    value falls back to leveldb_reader.
    """
    reader = lmdb_reader if type == 'lmdb' else leveldb_reader
    return reader(fpath)
def lmdb_reader(fpath):
    """Generator over an LMDB image database: yields (key, image, label).

    Images come back as uint8 arrays in Caffe's Datum layout, passed through
    flat_shape; labels are plain ints.
    """
    import lmdb
    env = lmdb.open(fpath)
    txn = env.begin()
    for key, raw in txn.cursor():
        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(raw)
        yield (key,
               flat_shape(caffe.io.datum_to_array(datum).astype(np.uint8)),
               int(datum.label))
def leveldb_reader(fpath):
    """Generator over a LevelDB image database: yields (key, image, label).

    Same record format as lmdb_reader; only the backing store differs.
    """
    import leveldb
    db = leveldb.LevelDB(fpath)
    for key, raw in db.RangeIter():
        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(raw)
        yield (key,
               flat_shape(caffe.io.datum_to_array(datum).astype(np.uint8)),
               int(datum.label))
def ShowInfo(correct, count, true_labels, predicted_lables, class_names, misclassified,
             filename='misclassifieds.txt',
             title='Receiver Operating Characteristic_ROC',
             title_CM='Confusion matrix, without normalization',
             title_CM_N='Normalized confusion matrix'):
    """Report accuracy, per-class metrics, an ROC curve and confusion matrices.

    Parameters
    ----------
    correct, count : number of correct predictions / total samples seen.
    true_labels, predicted_lables : parallel lists of ground-truth and
        predicted class indices (binary 0/1).
    class_names : display names for the classes.
    misclassified : keys of the misclassified samples; written to *filename*.
    title, title_CM, title_CM_N : plot titles for the ROC curve and the
        raw/normalized confusion matrices.

    Bug fix: the ROC plot now uses the *title* argument; it previously
    ignored it in favour of a hardcoded string.
    """
    sys.stdout.write("\rAccuracy: %.2f%%" % (100. * correct / count))
    sys.stdout.flush()
    print(", %i/%i corrects" % (correct, count))
    # Persist the keys of misclassified samples for later inspection.
    np.savetxt(filename, misclassified, fmt="%s")
    print(classification_report(y_true=true_labels,
                                y_pred=predicted_lables,
                                target_names=class_names))
    cm = confusion_matrix(y_true=true_labels,
                          y_pred=predicted_lables)
    print(cm)
    print(title)
    # NOTE(review): roc_curve is fed hard 0/1 predictions rather than
    # scores/probabilities, so the "curve" has a single operating point;
    # feed class probabilities here for a proper ROC — confirm with caller.
    false_positive_rate, true_positive_rate, thresholds = roc_curve(true_labels, predicted_lables)
    roc_auc = auc(false_positive_rate, true_positive_rate)
    plt.title(title)  # was hardcoded 'Receiver Operating Characteristic_ROC 1'
    plt.plot(false_positive_rate, true_positive_rate, 'b',
             label='AUC = %0.2f' % roc_auc)
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')  # chance diagonal
    plt.xlim([-0.1, 1.2])
    plt.ylim([-0.1, 1.2])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()
    # Confusion matrices: raw counts first, then row-normalized.
    cnf_matrix = cm
    np.set_printoptions(precision=2)
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=class_names,
                          title=title_CM)
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                          title=title_CM_N)
    plt.show()
def center_crop_and_demean(img, mean_pixel, cropx, cropy):
    """Center-crop a (c, h, w) image to (c, cropx, cropy) and subtract the
    per-channel mean, returning float32 in the same (c, h, w) layout.

    The transpose dance mirrors the original inline code: while the array is
    in (w, h, c) order the (3,) per-channel mean broadcasts over the channel
    axis, then the array is transposed back.
    """
    c, h, w = img.shape
    startx = h // 2 - cropx // 2
    starty = w // 2 - cropy // 2
    img = img[:, startx:startx + cropx, starty:starty + cropy]
    img = np.array(img.transpose(2, 1, 0), dtype=np.float32)
    img -= mean_pixel
    return img.transpose(2, 1, 0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
    parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
    parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
    parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
    parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
    args = parser.parse_args()

    # Extract the mean image from the .binaryproto file.
    mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
    with open(args.mean, 'rb') as f:  # with-block closes the file even on parse errors
        mean_blobproto_new.ParseFromString(f.read())
    mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)

    caffe.set_mode_gpu()

    predicted_lables = []
    true_labels = []
    misclassified = []
    class_names = ['unsafe', 'safe']
    count = 0
    correct = 0
    batch = []
    batch_size = 50
    cropx = 224  # NOTE: crop size is hardcoded to match the net's 224x224 input
    cropy = 224

    # CNN reconstruction and loading of the trained weights.
    net1 = caffe.Net(args.proto, args.model, caffe.TEST)
    net1.blobs['data'].reshape(batch_size, 3, 224, 224)
    # Per-channel mean of the mean image, used for mean subtraction.
    mean_pixel = mean_image[0].mean(1).mean(1)

    # One shared loop for both backends. The lmdb and leveldb code paths were
    # previously duplicated; the leveldb copy hardcoded the batch size to 50
    # and — critically — never incremented `count`, which broke the accuracy
    # denominator in ShowInfo.
    if args.db_type.lower() == 'lmdb':
        reader = lmdb_reader(args.db_path)
    else:
        reader = leveldb_reader(args.db_path)

    def _classify_batch(samples):
        """Forward one batch through the net and update the global tallies."""
        global correct
        net1.blobs['data'].data[...] = [
            center_crop_and_demean(img, mean_pixel, cropx, cropy)
            for _, img, _ in samples]
        out = net1.forward()
        plbl = np.asarray(out['pred']).argmax(axis=1)
        for j, (key, _, label) in enumerate(samples):
            if plbl[j] == label:
                correct += 1
            else:
                misclassified.append(key)
            predicted_lables.append(plbl[j])
            true_labels.append(label)

    for key, image, label in reader:
        count += 1
        if count % 5000 == 0:
            print('{0} samples processed so far'.format(count))
        batch.append((key, image, label))
        if len(batch) >= batch_size:
            _classify_batch(batch)
            batch = []

    # Process the trailing partial batch. Previously these samples were
    # counted in `count` but never classified, silently deflating the
    # reported accuracy whenever the dataset size was not a multiple of
    # batch_size.
    if batch:
        net1.blobs['data'].reshape(len(batch), 3, 224, 224)
        _classify_batch(batch)
        batch = []

    ShowInfo(correct, count, true_labels, predicted_lables, class_names, misclassified,
             filename='misclassifieds.txt',
             title='Receiver Operating Characteristic_ROC')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement