Guest User

Untitled

a guest
Oct 1st, 2018
266
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.90 KB | None | 0 0
  1. # load_model.py
  2. import mxnet as mx
  3. import numpy as np
  4. import picamera
  5. import cv2, os, urllib2, argparse, time
  6. from collections import namedtuple
  7. Batch = namedtuple('Batch', ['data'])
  8.  
  9.  
  10. class ImagenetModel(object):
  11.  
  12.     """
  13.    Loads a pre-trained model locally or from an external URL and returns an MXNet graph that is ready for prediction
  14.    """
  15.     def __init__(self, synset_path, network_prefix, params_url=None, symbol_url=None, synset_url=None, context=mx.cpu(), label_names=['prob_label'], input_shapes=[('data', (1,3,224,224))]):
  16.  
  17.         # Download the symbol set and network if URLs are provided
  18.         if params_url is not None:
  19.             print "fetching params from "+params_url
  20.             fetched_file = urllib2.urlopen(params_url)
  21.             with open(network_prefix+"-0000.params",'wb') as output:
  22.                 output.write(fetched_file.read())
  23.  
  24.         if symbol_url is not None:
  25.             print "fetching symbols from "+symbol_url
  26.             fetched_file = urllib2.urlopen(symbol_url)
  27.             with open(network_prefix+"-symbol.json",'wb') as output:
  28.                 output.write(fetched_file.read())
  29.  
  30.         if synset_url is not None:
  31.             print "fetching synset from "+synset_url
  32.             fetched_file = urllib2.urlopen(synset_url)
  33.             with open(synset_path,'wb') as output:
  34.                 output.write(fetched_file.read())
  35.  
  36.         # Load the symbols for the networks
  37.         with open(synset_path, 'r') as f:
  38.             self.synsets = [l.rstrip() for l in f]
  39.  
  40.         # Load the network parameters from default epoch 0
  41.         sym, arg_params, aux_params = mx.model.load_checkpoint(network_prefix, 0)
  42.  
  43.         # Load the network into an MXNet module and bind the corresponding parameters
  44.         self.mod = mx.mod.Module(symbol=sym, label_names=label_names, context=context)
  45.         self.mod.bind(for_training=False, data_shapes= input_shapes)
  46.         self.mod.set_params(arg_params, aux_params)
  47.         self.camera = None
  48.  
  49.     """
  50.    Takes in an image, reshapes it, and runs it through the loaded MXNet graph for inference returning the N top labels from the softmax
  51.    """
  52.     def predict_from_file(self, filename, reshape=(224, 224), N=5):
  53.  
  54.         topN = []
  55.  
  56.         # Switch RGB to BGR format (which ImageNet networks take)
  57.         img = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB)
  58.         if img is None:
  59.             return topN
  60.  
  61.         # Resize image to fit network input
  62.         img = cv2.resize(img, reshape)
  63.         img = np.swapaxes(img, 0, 2)
  64.         img = np.swapaxes(img, 1, 2)
  65.         img = img[np.newaxis, :]
  66.  
  67.         # Run forward on the image
  68.         self.mod.forward(Batch([mx.nd.array(img)]))
  69.         prob = self.mod.get_outputs()[0].asnumpy()
  70.         prob = np.squeeze(prob)
  71.  
  72.         # Extract the top N predictions from the softmax output
  73.         a = np.argsort(prob)[::-1]
  74.         for i in a[0:N]:
  75.             print('probability=%f, class=%s' %(prob[i], self.synsets[i]))
  76.             topN.append((prob[i], self.synsets[i]))
  77.         return topN
  78.  
  79.     """
  80.    Captures an image from the PiCamera, then sends it for prediction
  81.    """
  82.     def predict_from_cam(self, capfile='cap.jpg', reshape=(224, 224), N=5):
  83.         if self.camera is None:
  84.             self.camera = picamera.PiCamera()
  85.  
  86.         # Show quick preview of what's being captured
  87.         self.camera.start_preview()
  88.         time.sleep(3)
  89.         self.camera.capture(capfile)
  90.         self.camera.stop_preview()
  91.  
  92.         return self.predict_from_file(capfile)
  93.  
  94.  
  95. if __name__ == "__main__":
  96.     parser = argparse.ArgumentParser(description="pull and load pre-trained resnet model to classify one image")
  97.     parser.add_argument('--img', type=str, default='cam', help='input image for classification, if this is cam it captures from the PiCamera')
  98.     parser.add_argument('--prefix', type=str, default='squeezenet_v1.1', help='the prefix of the pre-trained model')
  99.     parser.add_argument('--label-name', type=str, default='prob_label', help='the name of the last layer in the loaded network (usually softmax_label)')
  100.     parser.add_argument('--synset', type=str, default='synset.txt', help='the path of the synset for the model')
  101.     parser.add_argument('--params-url', type=str, default=None, help='the (optional) url to pull the network parameter file from')
  102.     parser.add_argument('--symbol-url', type=str, default=None, help='the (optional) url to pull the network symbol JSON from')
  103.     parser.add_argument('--synset-url', type=str, default=None, help='the (optional) url to pull the synset file from')
  104.     args = parser.parse_args()
  105.     mod = ImagenetModel(args.synset, args.prefix, label_names=[args.label_name], params_url=args.params_url, symbol_url=args.symbol_url, synset_url=args.synset_url)
  106.     print "predicting on "+args.img
  107.     if args.img == "cam":
  108.         print mod.predict_from_cam()
  109.     else:
  110.         print mod.predict_from_file(args.img)
Add Comment
Please, Sign In to add comment