Advertisement
Guest User

Untitled

a guest
Mar 7th, 2017
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.21 KB | None | 0 0
# Dependency: run `pip install lmdb` from a command prompt (provides the lmdb module)
  2. #pip install protobuf
  3. import sys
  4. import caffe
  5. import numpy as np
  6. import lmdb
  7. import argparse
  8. from collections import defaultdict
  9. from sklearn.metrics import classification_report
  10.  
  11. def flat_shape(x):
  12. "Returns x without singleton dimension, eg: (1,28,28) -> (28,28)"
  13. return x.reshape(filter(lambda s: s > 1, x.shape))
  14.  
  15. def db_reader(fpath, type='lmdb'):
  16. if type == 'lmdb':
  17. return lmdb_reader(fpath)
  18. else:
  19. return leveldb_reader(fpath)
  20.  
  21.  
  22. def lmdb_reader(fpath):
  23. import lmdb
  24. lmdb_env = lmdb.open(fpath)
  25. lmdb_txn = lmdb_env.begin()
  26. lmdb_cursor = lmdb_txn.cursor()
  27.  
  28. for key, value in lmdb_cursor:
  29. datum = caffe.proto.caffe_pb2.Datum()
  30. datum.ParseFromString(value)
  31. label = int(datum.label)
  32. image = caffe.io.datum_to_array(datum).astype(np.uint8)
  33. yield (key, flat_shape(image), label)
  34.  
  35. def leveldb_reader(fpath):
  36. import leveldb
  37. db = leveldb.LevelDB(fpath)
  38.  
  39. for key, value in db.RangeIter():
  40. datum = caffe.proto.caffe_pb2.Datum()
  41. datum.ParseFromString(value)
  42. label = int(datum.label)
  43. image = caffe.io.datum_to_array(datum).astype(np.uint8)
  44. yield (key, flat_shape(image), label)
  45.  
  46.  
  47. if __name__ == "__main__":
  48. parser = argparse.ArgumentParser()
  49. parser.add_argument('--proto', help='E:/CAFFE/caffe-windows/models/blvc_alexnets/deploy.prototxt', type=str, required=True)
  50. parser.add_argument('--model', help='E:/CAFFE/caffe-windows/models/blvc_alexnets/caffe_alexnet_sinatrain_iter_1606.caffemodel', type=str, required=True)
  51. parser.add_argument('--mean', help='E:/CAFFE/caffe-windows/models/blvc_alexnets/mean_imagetest.binaryproto', type=str, required=True)
  52. #group = parser.add_mutually_exclusive_group(required=True)
  53. parser.add_argument('--db_type', help='lmdb', type=str, required=True)
  54. parser.add_argument('--db_path', help='E:/CAFFE/caffe-windows/models/blvc_alexnets/Mydataset_test_lmdb', type=str, required=True)
  55. args = parser.parse_args()
  56.  
  57. # Extract mean from the mean image file
  58. mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
  59. f = open(args.mean, 'rb')
  60. mean_blobproto_new.ParseFromString(f.read())
  61. mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)
  62. f.close()
  63.  
  64. # CNN reconstruction and loading the trained weights
  65. net = caffe.Net(args.proto, args.model, caffe.TEST)
  66. # You may also use set_mode_cpu() if you didnt compile caffe with gpu support
  67. caffe.set_mode_gpu()
  68.  
  69. print ("args", vars(args))
  70.  
  71. reader = db_reader(args.db_path, args.db_type.lower())
  72.  
  73. predicted_lables=[]
  74. true_labels = []
  75. class_names = ['b','p','e','n']
  76.  
  77. for i, image, label in reader:
  78. image_caffe = image.reshape(1, *image.shape)
  79. out = net.forward_all(data=np.asarray([ image_caffe ])- mean_image)
  80. plabel = int(out['prob'][0].argmax(axis=0))
  81. predicted_lables.append(plabel)
  82. true_labels.append(label)
  83. print(i,' processed!')
  84.  
  85. print( classification_report(y_true=true_labels,
  86. y_pred=predicted_lables,
  87. target_names=class_names))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement