Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Prerequisites (run these in a command prompt):
# pip install lmdb        (needed by the LMDB reader)
# pip install protobuf
- import sys
- import caffe
- import numpy as np
- import lmdb
- import argparse
- from collections import defaultdict
- from sklearn.metrics import classification_report
def flat_shape(x):
    """Return *x* without its length-1 axes, e.g. (1, 28, 28) -> (28, 28).

    ``np.squeeze`` removes exactly the singleton axes. Zero-length axes are
    kept, so empty arrays still work. The result is always a concrete shape;
    a lazy ``filter`` object is not a valid reshape argument on Python 3.

    Parameters
    ----------
    x : numpy.ndarray
        Array to squeeze.

    Returns
    -------
    numpy.ndarray
        *x* (or a view of it) with every length-1 dimension removed. The
        result is 0-d if every axis had length 1.
    """
    return np.squeeze(x)
def db_reader(fpath, type='lmdb'):
    """Return a generator of ``(key, image, label)`` records from a Caffe DB.

    Parameters
    ----------
    fpath : str
        Path to the database directory.
    type : str, optional
        Backend name, ``'lmdb'`` (default) or ``'leveldb'``, matched
        case-insensitively. The parameter shadows the ``type`` builtin but
        keeps its name for keyword-argument compatibility.

    Returns
    -------
    generator
        The generator created by ``lmdb_reader`` or ``leveldb_reader``.

    Raises
    ------
    ValueError
        If *type* names neither backend. Unknown values are rejected rather
        than routed to LevelDB, so a typo such as ``'lmbd'`` fails here with a
        clear message.
    """
    backend = type.lower()
    if backend == 'lmdb':
        return lmdb_reader(fpath)
    if backend == 'leveldb':
        return leveldb_reader(fpath)
    raise ValueError(
        "unsupported db type %r; expected 'lmdb' or 'leveldb'" % (type,))
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` for every record in a Caffe LMDB database.

    The environment is opened read-only, and it is closed when iteration
    finishes or the generator is closed/garbage-collected. Previously the
    environment and its transaction were never released.

    Parameters
    ----------
    fpath : str
        Path to the LMDB directory.

    Yields
    ------
    tuple
        ``(key, image, label)``:

        - ``key`` is the raw key from the cursor.
        - ``image`` is the decoded ``Datum`` cast to ``uint8`` with its
          length-1 axes removed by ``flat_shape``.
        - ``label`` is the datum label as ``int``.
    """
    import lmdb
    # readonly=True: only read an existing database; a mistyped path then
    # raises instead of creating a fresh, empty database to iterate.
    env = lmdb.open(fpath, readonly=True)
    try:
        with env.begin() as txn:
            for key, value in txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                yield (key, flat_shape(image), int(datum.label))
    finally:
        env.close()
def leveldb_reader(fpath):
    """Yield ``(key, image, label)`` for every record in a Caffe LevelDB database.

    The database is opened with ``create_if_missing=False``, so a mistyped
    path raises at open time instead of silently creating (and iterating) a
    new, empty database.

    Parameters
    ----------
    fpath : str
        Path to the LevelDB directory.

    Yields
    ------
    tuple
        ``(key, image, label)``:

        - ``key`` is the raw key from ``RangeIter``.
        - ``image`` is the decoded ``Datum`` cast to ``uint8`` with its
          length-1 axes removed by ``flat_shape``.
        - ``label`` is the datum label as ``int``.
    """
    import leveldb
    db = leveldb.LevelDB(fpath, create_if_missing=False)
    for key, value in db.RangeIter():
        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(value)
        image = caffe.io.datum_to_array(datum).astype(np.uint8)
        yield (key, flat_shape(image), int(datum.label))
- if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument('--proto', help='E:/CAFFE/caffe-windows/models/blvc_alexnets/deploy.prototxt', type=str, required=True)
- parser.add_argument('--model', help='E:/CAFFE/caffe-windows/models/blvc_alexnets/caffe_alexnet_sinatrain_iter_1606.caffemodel', type=str, required=True)
- parser.add_argument('--mean', help='E:/CAFFE/caffe-windows/models/blvc_alexnets/mean_imagetest.binaryproto', type=str, required=True)
- #group = parser.add_mutually_exclusive_group(required=True)
- parser.add_argument('--db_type', help='lmdb', type=str, required=True)
- parser.add_argument('--db_path', help='E:/CAFFE/caffe-windows/models/blvc_alexnets/Mydataset_test_lmdb', type=str, required=True)
- args = parser.parse_args()
- # Extract mean from the mean image file
- mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
- f = open(args.mean, 'rb')
- mean_blobproto_new.ParseFromString(f.read())
- mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)
- f.close()
- # CNN reconstruction and loading the trained weights
- net = caffe.Net(args.proto, args.model, caffe.TEST)
- # You may also use set_mode_cpu() if you didnt compile caffe with gpu support
- caffe.set_mode_gpu()
- print ("args", vars(args))
- reader = db_reader(args.db_path, args.db_type.lower())
- predicted_lables=[]
- true_labels = []
- class_names = ['b','p','e','n']
- for i, image, label in reader:
- image_caffe = image.reshape(1, *image.shape)
- out = net.forward_all(data=np.asarray([ image_caffe ])- mean_image)
- plabel = int(out['prob'][0].argmax(axis=0))
- predicted_lables.append(plabel)
- true_labels.append(label)
- print(i,' processed!')
- print( classification_report(y_true=true_labels,
- y_pred=predicted_lables,
- target_names=class_names))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement