Guest User

Untitled

a guest
Jul 18th, 2018
68
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 3.75 KB | None | 0 0
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import nested_scopes
  4. from __future__ import print_function
  5.  
  6. import time
  7. import argparse
  8. from argparse import RawTextHelpFormatter
  9.  
  10. from grpc.beta import implementations
  11. import numpy as np
  12. import tensorflow as tf
  13.  
  14. from tensorflow_serving.apis import predict_pb2
  15. from tensorflow_serving.apis import prediction_service_pb2
  16.  
  17. from utils import load_image_into_numpy_array
  18. from utils import visualize_serving_bounding_boxes
  19.  
# Log at INFO level so the timing messages emitted through tf.logging.info
# in main() are actually printed.
tf.logging.set_verbosity(tf.logging.INFO)
  21.  
  22.  
  23. def load_input_tensor(input_image, input_type):
  24. if input_type == 'image_tensor':
  25. image_np = load_image_into_numpy_array(input_image)
  26. image_np_expanded = np.expand_dims(image_np, axis=0).astype(np.float32)
  27. tensor = tf.contrib.util.make_tensor_proto(image_np_expanded)
  28. elif input_type == 'encoded_image_string_tensor':
  29. with open(input_image, 'rb') as f:
  30. data = f.read()
  31. tensor = tf.contrib.util.make_tensor_proto(data, shape=[1])
  32. else:
  33. raise ValueError("Unsupported input type: %s" % input_type)
  34. return tensor
  35.  
  36.  
  37. def main(args):
  38. host, port = args.server.split(':')
  39. channel = implementations.insecure_channel(host, int(port))
  40. stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
  41.  
  42. # Create prediction request object
  43. start_ts = time.time()
  44. request = predict_pb2.PredictRequest()
  45.  
  46. # Specify model name (must be the same as when the TensorFlow serving
  47. # was started)
  48. request.model_spec.name = args.model_name
  49.  
  50. input_tensor = load_input_tensor(args.input_image, args.input_type)
  51. request.inputs['inputs'].CopyFrom(input_tensor)
  52. tf.logging.info("Image load time: %s sec" % (time.time() - start_ts))
  53.  
  54. # Call the prediction server
  55. start_ts = time.time()
  56. result = stub.Predict(request, 60.0) # 60 secs timeout
  57. tf.logging.info("Inference time: %s sec" % (time.time() - start_ts))
  58. visualize_serving_bounding_boxes(args.output_directory, args.input_image, args.label_map,
  59. args.max_classes, result)
  60.  
  61.  
  62. if __name__ == '__main__':
  63. parser = argparse.ArgumentParser(description="Object detection client.",
  64. formatter_class=RawTextHelpFormatter)
  65. parser.add_argument('--server',
  66. type=str,
  67. required=True,
  68. help='PredictionService host:port')
  69. parser.add_argument('--model_name',
  70. type=str,
  71. required=True,
  72. help='Name of the model')
  73. parser.add_argument('--input_image',
  74. type=str,
  75. required=True,
  76. help='Path to input image')
  77. parser.add_argument('--output_directory',
  78. type=str,
  79. required=True,
  80. help='Path to output directory')
  81. parser.add_argument('--label_map',
  82. type=str,
  83. required=True,
  84. help='Path to label map file')
  85. parser.add_argument('--max_classes',
  86. type=int,
  87. default=100,
  88. help='Maximum number of classes')
  89. parser.add_argument('--input_type',
  90. choices=['image_tensor',
  91. 'encoded_image_string_tensor',
  92. 'tf_example'],
  93. default='image_tensor',
  94. help='Type of input node. Can be '
  95. 'one of [`image_tensor`, '
  96. '`encoded_image_string_tensor`, '
  97. '`tf_example`]')
  98.  
  99. args = parser.parse_args()
  100. main(args)
Add Comment
Please sign in to add a comment.