SHARE
TWEET

Untitled

a guest Jun 17th, 2019 56 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. Model init time: 0.77 sec
  2. Processing time: 37.58 sec
  3.      
  4. Model init time: 0.76 sec
  5. Processing time: 20.16 sec
  6.      
  7. Model init time: 0.78 sec
  8. Processing time: 39.14 sec
  9.      
  10. import os
  11. import glob
  12. import time
  13. import argparse
  14. from multiprocessing.pool import ThreadPool
  15. import multiprocessing
  16. import itertools
  17.  
  18. import tensorflow as tf
  19. import numpy as np
  20. from tqdm import tqdm
  21. import cv2
  22.  
  23. MODEL_FILEPATH = './tensorflow_example/inception_v3_2016_08_28_frozen.pb'
  24.  
  25. def get_image_filepaths(dataset_dir):
  26.     if not os.path.isdir(dataset_dir):
  27.         raise Exception(dataset_dir, 'not dir!')
  28.  
  29.     img_filepaths = []
  30.     extensions = ['**/*.jpg', '**/*.png', '**/*.JPG', '**/*.PNG']
  31.     for ext in extensions:
  32.         img_filepaths.extend(glob.iglob(os.path.join(dataset_dir, ext), recursive=True))
  33.  
  34.     return img_filepaths
  35.  
  36.  
  37. class ModelWrapper():
  38.     def __init__(self, model_filepath):
  39.         # TODO: estimate this from graph itself
  40.         # Hardcoded for inception_v3_2016_08_28_frozen.pb
  41.         self.input_node_names = ['input']
  42.         self.output_node_names = ['InceptionV3/Predictions/Reshape_1']
  43.         self.input_img_w = 299
  44.         self.input_img_h = 299
  45.  
  46.         input_tensor_names = [name + ":0" for name in self.input_node_names]
  47.         output_tensor_names = [name + ":0" for name in self.output_node_names]
  48.  
  49.         self.graph = self.load_graph(model_filepath)
  50.  
  51.         self.inputs = []
  52.         for input_tensor_name in input_tensor_names:
  53.             self.inputs.append(self.graph.get_tensor_by_name(input_tensor_name))
  54.  
  55.         self.outputs = []
  56.         for output_tensor_name in output_tensor_names:
  57.             self.outputs.append(self.graph.get_tensor_by_name(output_tensor_name))
  58.  
  59.         config_proto = tf.ConfigProto(device_count={'GPU': 0},
  60.                                       intra_op_parallelism_threads=1,
  61.                                       inter_op_parallelism_threads=1)
  62.         self.sess = tf.Session(graph=self.graph, config=config_proto)
  63.  
  64.     def load_graph(self, model_filepath):
  65.         # Expects frozen graph in .pb format
  66.         with tf.gfile.GFile(model_filepath, "rb") as f:
  67.             graph_def = tf.GraphDef()
  68.             graph_def.ParseFromString(f.read())
  69.         with tf.Graph().as_default() as graph:
  70.             tf.import_graph_def(graph_def, name="")
  71.         return graph
  72.  
  73.     def predict(self, img):
  74.         h, w, c = img.shape
  75.         if h != self.input_img_h or w != self.input_img_w:
  76.             img = cv2.resize(img, (self.input_img_w, self.input_img_h))
  77.  
  78.         batch = img[np.newaxis, ...]
  79.         feed_dict = {self.inputs[0] : batch}
  80.         outputs = self.sess.run(self.outputs, feed_dict=feed_dict) # (1, 1001)
  81.  
  82.         return outputs
  83.  
  84.  
  85. def process_single_file(args):
  86.     model, img_filepath = args
  87.  
  88.     img = cv2.imread(img_filepath)
  89.     output = model.predict(img)
  90.  
  91.  
  92. def process_dataset(dataset_dir):
  93.     img_filepaths = get_image_filepaths(dataset_dir)
  94.  
  95.     start = time.time()
  96.     model = ModelWrapper(MODEL_FILEPATH)
  97.     print('Model init time:', round(time.time() - start, 2), 'sec')
  98.  
  99.     start = time.time()
  100.     n_cpu = multiprocessing.cpu_count()
  101.     for _ in tqdm(ThreadPool(n_cpu).imap_unordered(process_single_file,
  102.                                                    zip(itertools.repeat(model), img_filepaths)),
  103.                                                    total=len(img_filepaths)):
  104.         pass
  105.     print('Processing time:', round(time.time() - start, 2), 'sec')
  106.  
  107.  
  108. if __name__ == "__main__":
  109.     parser = argparse.ArgumentParser()
  110.     parser.add_argument(dest='dataset_dir')
  111.     args = parser.parse_args()
  112.  
  113.     process_dataset(args.dataset_dir)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top