Advertisement
Guest User

Untitled

a guest
Dec 9th, 2019
116
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.52 KB | None | 0 0
  1. import tensorflow as tf
  2. import cv2
  3. import os
  4.  
  5. def get_frozen_graph(graph_file):
  6. """Read Frozen Graph file from disk."""
  7. with tf.gfile.FastGFile(graph_file, "rb") as f:
  8. graph_def = tf.GraphDef()
  9. graph_def.ParseFromString(f.read())
  10. return graph_def
  11.  
  12. # The TensorRT inference graph file downloaded from Colab or your local machine.
  13. pb_fname = os.path.join(os.getcwd(), "faster_rcnn_inception_resnet_v2_atrous_coco_2018_01_28", "frozen_inference_graph.pb")
  14. trt_graph = get_frozen_graph(pb_fname)
  15.  
  16. input_names = ['image_tensor']
  17.  
  18. # Create session and load graph
  19. tf_config = tf.ConfigProto()
  20. tf_config.gpu_options.allow_growth = True
  21. tf_sess = tf.Session(config=tf_config)
  22. tf.import_graph_def(trt_graph, name='')
  23.  
  24. tf_input = tf_sess.graph.get_tensor_by_name(input_names[0] + ':0')
  25. tf_scores = tf_sess.graph.get_tensor_by_name('detection_scores:0')
  26. tf_boxes = tf_sess.graph.get_tensor_by_name('detection_boxes:0')
  27. tf_classes = tf_sess.graph.get_tensor_by_name('detection_classes:0')
  28. tf_num_detections = tf_sess.graph.get_tensor_by_name('num_detections:0')
  29.  
  30.  
  31. IMAGE_PATH = os.path.join(os.getcwd(), "testimages", "000002_491724089556.png")
  32. image = cv2.imread(IMAGE_PATH)
  33. image = cv2.resize(image, (300, 300))
  34.  
  35. scores, boxes, classes, num_detections = tf_sess.run([tf_scores, tf_boxes, tf_classes, tf_num_detections], feed_dict={
  36. tf_input: image[None, ...]
  37. })
  38. boxes = boxes[0] # index by 0 to remove batch dimension
  39. scores = scores[0]
  40. classes = classes[0]
  41. num_detections = int(num_detections[0])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement