Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- Model init time: 0.77 sec
- Processing time: 37.58 sec
- Model init time: 0.76 sec
- Processing time: 20.16 sec
- Model init time: 0.78 sec
- Processing time: 39.14 sec
- import os
- import glob
- import time
- import argparse
- from multiprocessing.pool import ThreadPool
- import multiprocessing
- import itertools
- import tensorflow as tf
- import numpy as np
- from tqdm import tqdm
- import cv2
# Frozen Inception v3 graph (2016-08-28 release) that process_dataset loads.
# Relative path, so it is resolved against the current working directory.
MODEL_FILEPATH = "./tensorflow_example/inception_v3_2016_08_28_frozen.pb"
def get_image_filepaths(dataset_dir, extensions=('.jpg', '.png')):
    """Recursively collect image file paths under ``dataset_dir``.

    Args:
        dataset_dir: Root directory to search; subdirectories are included.
        extensions: File suffixes to keep, each including the leading dot.
            Matching is case-insensitive, so ``.jpg`` also matches ``.JPG``
            and ``.Jpg``.

    Returns:
        Sorted list of matching file paths, each listed exactly once.

    Raises:
        NotADirectoryError: If ``dataset_dir`` is not an existing directory.
            This is an ``Exception`` subclass, so existing broad handlers
            still catch it.
    """
    if not os.path.isdir(dataset_dir):
        raise NotADirectoryError('{}: not dir!'.format(dataset_dir))

    wanted = {ext.lower() for ext in extensions}
    # Walk once and compare suffixes case-insensitively, instead of running one
    # glob per spelling. Separate '*.jpg' and '*.JPG' globs both match the same
    # file on case-insensitive filesystems, which would list it twice.
    # glob.escape keeps '[', '*' and '?' in the directory name literal.
    pattern = os.path.join(glob.escape(dataset_dir), '**', '*')
    img_filepaths = {
        path
        for path in glob.iglob(pattern, recursive=True)
        if os.path.splitext(path)[1].lower() in wanted and os.path.isfile(path)
    }
    return sorted(img_filepaths)
class ModelWrapper:
    """Run a frozen TensorFlow (1.x API) graph on single images.

    Node names and the input size are hard-coded for
    ``inception_v3_2016_08_28_frozen.pb``. One session is opened in
    ``__init__``. Release it with :meth:`close`, or use the wrapper as a
    context manager (``with ModelWrapper(path) as model: ...``).
    """

    def __init__(self, model_filepath, config_proto=None):
        """Load the frozen graph and open a session on it.

        Args:
            model_filepath: Path to a frozen ``GraphDef`` in binary ``.pb``
                format.
            config_proto: Optional ``tf.ConfigProto`` for the session. The
                default is CPU only (GPU device count 0), with one intra-op
                and one inter-op thread.
        """
        # TODO: estimate this from graph itself
        # Hardcoded for inception_v3_2016_08_28_frozen.pb
        self.input_node_names = ['input']
        self.output_node_names = ['InceptionV3/Predictions/Reshape_1']
        self.input_img_w = 299
        self.input_img_h = 299

        # Tensor names are "<op name>:<output index>"; ":0" selects each op's
        # first output.
        input_tensor_names = [name + ":0" for name in self.input_node_names]
        output_tensor_names = [name + ":0" for name in self.output_node_names]

        self.graph = self.load_graph(model_filepath)
        self.inputs = [self.graph.get_tensor_by_name(name)
                       for name in input_tensor_names]
        self.outputs = [self.graph.get_tensor_by_name(name)
                        for name in output_tensor_names]

        if config_proto is None:
            config_proto = tf.ConfigProto(device_count={'GPU': 0},
                                          intra_op_parallelism_threads=1,
                                          inter_op_parallelism_threads=1)
        self.sess = tf.Session(graph=self.graph, config=config_proto)

    def load_graph(self, model_filepath):
        """Parse a frozen ``.pb`` file and return a new ``tf.Graph`` holding it.

        The graph is imported with ``name=""``, so its nodes keep their
        original names (no import prefix).
        """
        with tf.gfile.GFile(model_filepath, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name="")
        return graph

    def close(self):
        """Release the TensorFlow session. The wrapper is unusable afterwards."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()

    def predict(self, img):
        """Run the graph on one image.

        Args:
            img: ``H x W x 3`` numpy image, for example as returned by
                ``cv2.imread`` (BGR channel order). A 2-D grayscale array is
                expanded to three channels. The image is resized to the model
                input size when needed.

        Returns:
            The list returned by ``Session.run``: one array per output tensor.

        Raises:
            ValueError: If ``img`` is ``None``, which is what ``cv2.imread``
                returns when it cannot read a file.
        """
        if img is None:
            raise ValueError('img is None; the image could not be read')
        if img.ndim == 2:
            # The input placeholder expects 3 channels. Replicate the gray
            # channel, keeping cv2's BGR convention.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        h, w = img.shape[:2]
        if h != self.input_img_h or w != self.input_img_w:
            img = cv2.resize(img, (self.input_img_w, self.input_img_h))
        # NOTE(review): pixels are fed as-is (BGR, unscaled). The model
        # presumably expects a specific channel order and value range; confirm
        # against its preprocessing before trusting the predictions.
        batch = img[np.newaxis, ...]  # add a leading batch axis of size 1
        feed_dict = {self.inputs[0]: batch}
        # Per the original author's note, outputs[0] has shape (1, 1001).
        return self.sess.run(self.outputs, feed_dict=feed_dict)
def process_single_file(args):
    """Read one image from disk and run the model on it.

    Takes a single ``(model, img_filepath)`` tuple so it can be passed
    directly to ``Pool.imap_unordered``.

    Args:
        args: ``(model, img_filepath)``, where ``model`` provides
            ``predict(img)``.

    Returns:
        The value of ``model.predict`` for the image.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    model, img_filepath = args
    img = cv2.imread(img_filepath)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        # Report the offending path instead of failing later with an
        # unrelated AttributeError.
        raise ValueError('could not read image: {}'.format(img_filepath))
    return model.predict(img)
def process_dataset(dataset_dir, n_threads=None):
    """Benchmark model inference over every image found under ``dataset_dir``.

    Prints the model initialisation time and the total processing time.

    Args:
        dataset_dir: Directory searched recursively for images.
        n_threads: Number of worker threads. Defaults to
            ``multiprocessing.cpu_count()``.
    """
    img_filepaths = get_image_filepaths(dataset_dir)

    # perf_counter is monotonic, so the durations are unaffected by
    # wall-clock adjustments.
    start = time.perf_counter()
    model = ModelWrapper(MODEL_FILEPATH)
    print('Model init time:', round(time.perf_counter() - start, 2), 'sec')

    if n_threads is None:
        n_threads = multiprocessing.cpu_count()

    start = time.perf_counter()
    try:
        # All worker threads share the single model instance (and its session).
        # The `with` block shuts the pool down once every result is consumed.
        with ThreadPool(n_threads) as pool:
            jobs = zip(itertools.repeat(model), img_filepaths)
            results = pool.imap_unordered(process_single_file, jobs)
            for _ in tqdm(results, total=len(img_filepaths)):
                pass
            elapsed = time.perf_counter() - start
    finally:
        # Release the session even if a worker raised.
        model.sess.close()
    print('Processing time:', round(elapsed, 2), 'sec')
if __name__ == "__main__":
    # Command line: a single positional argument, the dataset root directory.
    cli = argparse.ArgumentParser()
    cli.add_argument(dest='dataset_dir')
    process_dataset(cli.parse_args().dataset_dir)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement