Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import nested_scopes
- from __future__ import print_function
- import time
- import argparse
- from argparse import RawTextHelpFormatter
- from grpc.beta import implementations
- import numpy as np
- import tensorflow as tf
- from tensorflow_serving.apis import predict_pb2
- from tensorflow_serving.apis import prediction_service_pb2
- from utils import load_image_into_numpy_array
- from utils import visualize_serving_bounding_boxes
- tf.logging.set_verbosity(tf.logging.INFO)
def load_input_tensor(input_image, input_type):
    """Build a TensorProto for one image file, matching the serving input node.

    Args:
        input_image: Path to the image file on disk.
        input_type: Serving input node type; either 'image_tensor' or
            'encoded_image_string_tensor'.

    Returns:
        A tf.TensorProto wrapping the image data, batched to size 1.

    Raises:
        ValueError: For any other input_type (note: 'tf_example' is accepted
            by the CLI but not built by this client).
    """
    if input_type == 'image_tensor':
        # Decode the image to a numpy array and add a leading batch dimension.
        batched = np.expand_dims(
            load_image_into_numpy_array(input_image), axis=0).astype(np.float32)
        return tf.contrib.util.make_tensor_proto(batched)
    if input_type == 'encoded_image_string_tensor':
        # Send the raw encoded bytes; the server performs the decode.
        with open(input_image, 'rb') as image_file:
            raw_bytes = image_file.read()
        return tf.contrib.util.make_tensor_proto(raw_bytes, shape=[1])
    raise ValueError("Unsupported input type: %s" % input_type)
def main(args):
    """Send one PredictRequest to a TF Serving host and visualize the result.

    Args:
        args: Parsed CLI namespace with server, model_name, input_image,
            input_type, output_directory, label_map and max_classes.
    """
    host, port = args.server.split(':')
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    # Build the prediction request. The model name must match the name the
    # TensorFlow Serving instance was started with.
    load_start = time.time()
    request = predict_pb2.PredictRequest()
    request.model_spec.name = args.model_name
    request.inputs['inputs'].CopyFrom(
        load_input_tensor(args.input_image, args.input_type))
    tf.logging.info("Image load time: %s sec" % (time.time() - load_start))

    # Run remote inference.
    infer_start = time.time()
    result = stub.Predict(request, 60.0)  # 60 secs timeout
    tf.logging.info("Inference time: %s sec" % (time.time() - infer_start))

    visualize_serving_bounding_boxes(args.output_directory, args.input_image,
                                     args.label_map, args.max_classes, result)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Object detection client.",
                                     formatter_class=RawTextHelpFormatter)
    # Required string arguments: connection, model and file locations.
    for flag, help_text in (
            ('--server', 'PredictionService host:port'),
            ('--model_name', 'Name of the model'),
            ('--input_image', 'Path to input image'),
            ('--output_directory', 'Path to output directory'),
            ('--label_map', 'Path to label map file')):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    parser.add_argument('--max_classes',
                        type=int,
                        default=100,
                        help='Maximum number of classes')
    parser.add_argument('--input_type',
                        choices=['image_tensor',
                                 'encoded_image_string_tensor',
                                 'tf_example'],
                        default='image_tensor',
                        help='Type of input node. Can be one of '
                             '[`image_tensor`, `encoded_image_string_tensor`, '
                             '`tf_example`]')
    main(parser.parse_args())
Add Comment
Please sign in to add a comment.