import multiprocessing as mp
import time, logging, itertools
import cv2 as cv
import numpy as np
import os
import six.moves.urllib as urllib
import tarfile
import tensorflow as tf

from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util


def run_inference_for_single_image(sess, image, graph):
    with graph.as_default():
        # Get handles to input and output tensors
        ops = tf.get_default_graph().get_operations()
        all_tensor_names = {output.name for op in ops for output in op.outputs}
        tensor_dict = {}
        for key in [
            'num_detections', 'detection_boxes', 'detection_scores',
            'detection_classes', 'detection_masks'
        ]:
            tensor_name = key + ':0'
            if tensor_name in all_tensor_names:
                tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
                    tensor_name)
        if 'detection_masks' in tensor_dict:
            # The following processing is only for a single image.
            detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
            detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
            # Reframing translates the masks from box coordinates to image
            # coordinates so they fit the image size.
            real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
            detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
            detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
            detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
                detection_masks, detection_boxes, image.shape[1], image.shape[2])
            detection_masks_reframed = tf.cast(
                tf.greater(detection_masks_reframed, 0.5), tf.uint8)
            # Follow the convention by adding back the batch dimension.
            tensor_dict['detection_masks'] = tf.expand_dims(
                detection_masks_reframed, 0)
        image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

        # Run inference
        output_dict = sess.run(tensor_dict,
                               feed_dict={image_tensor: image})

        # All outputs are float32 numpy arrays, so convert types as appropriate.
        output_dict['num_detections'] = int(output_dict['num_detections'][0])
        output_dict['detection_classes'] = output_dict[
            'detection_classes'][0].astype(np.uint8)
        output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
        output_dict['detection_scores'] = output_dict['detection_scores'][0]
        if 'detection_masks' in output_dict:
            output_dict['detection_masks'] = output_dict['detection_masks'][0]
    return output_dict
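
# For reference, the returned output_dict typically looks like the sketch
# below (illustrative shapes for ssd_mobilenet_v1_coco, which pads its
# outputs to 100 detections and produces no instance masks):
#   output_dict['num_detections']     -> int, e.g. 3
#   output_dict['detection_boxes']    -> float32 array, shape (100, 4),
#                                        normalized [ymin, xmin, ymax, xmax]
#   output_dict['detection_classes']  -> uint8 array, shape (100,), COCO ids
#   output_dict['detection_scores']   -> float32 array, shape (100,)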


# Holds a frame (numpy array) plus its creation timestamp and an
# auto-incrementing sequence number.
class frame_container:
    counter = itertools.count()

    def __init__(self, frame):
        self.frame = frame
        self.timestamp = time.time()
        self.seq = next(self.counter)
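
# A minimal usage sketch (shown as comments so the script itself is unchanged):
#   fc = frame_container(np.zeros((500, 500, 3), np.uint8))
#   fc.seq        # 0, then 1, 2, ... for subsequent containers
#   fc.timestamp  # wall-clock creation time from time.time()
# Note that `counter` is a class attribute, so sequence numbers are only
# unique within a single process; here only the producer creates containers.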

# Displays the annotated frames produced by the object detection process.
def video_player(q):
    while True:
        # Block until the consumer delivers the next annotated frame.
        frame = q.get()
        if debug == 1: print("[VP] Received frame #{}".format(frame.seq))
        cv.imshow('SSD Output', frame.frame)
        if cv.waitKey(1000 // fps) & 0xFF == ord('q'):
            break

# Puts frames into the model process queue, downsampling from 30 fps to 5 fps.
def producer(q):
    cap = cv.VideoCapture('road.mp4')
    n_frame = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of file: reopen the capture so the video loops.
            cap = cv.VideoCapture('road.mp4')
            continue
        # Scale the frame so that its longest side is 500 px.
        dims = frame.shape[:2]
        max_side = np.max(dims)
        res = cv.resize(frame, None, fx=500 / max_side, fy=500 / max_side, interpolation=cv.INTER_CUBIC)
        if not q.full():
            # Queue only one frame in six.
            if n_frame == 0:
                container = frame_container(res)
                q.put(container, False)
                # cv.imshow('Original Video', container.frame)
                # Note: waitKey has no effect unless a HighGUI window is open.
                if cv.waitKey(1000 // fps) & 0xFF == ord('q'):
                    break
                if debug == 1: print("[P] Produced frame #{}".format(container.seq))
        n_frame = (n_frame + 1) % 6
    cap.release()
    cv.destroyAllWindows()
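
# Sanity check on the downsampling above: the source video is assumed to be
# 30 fps and n_frame cycles through 0..5, so exactly one frame in six is
# queued, giving 30 / 6 = 5 fps, matching the global `fps` used both for
# playback pacing and for the consumer's drift accounting.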

# Waits for batch_size frames, runs them through the object detection model,
# and sends the annotated output to the video player process.
def consumer(q, output_q):
    # Model to download (see the commented-out block below).
    MODEL_NAME = 'ssd_mobilenet_v1_coco_2018_01_28'
    MODEL_FILE = MODEL_NAME + '.tar.gz'
    DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

    # Path to the frozen detection graph. This is the actual model used for
    # the object detection.
    PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'

    # List of strings used to add the correct label to each box.
    PATH_TO_LABELS = os.path.abspath('../data/mscoco_label_map.pbtxt')

    # Uncomment to download a new model.
    # opener = urllib.request.URLopener()
    # opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
    # tar_file = tarfile.open(MODEL_FILE)
    # for file in tar_file.getmembers():
    #     file_name = os.path.basename(file.name)
    #     if 'frozen_inference_graph.pb' in file_name:
    #         tar_file.extract(file, os.getcwd())

    # Load the frozen graph into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)

    input_images = []
    inference_samples_num = 4
    # Smoothing factor for the exponential moving average of inference time.
    beta = 1 - 1 / inference_samples_num
    inference_avg = -1
    inference_current = -1
    drift = 0

    with tf.Session(graph=detection_graph) as sess:
        while True:

            if not q.empty():
                if debug == 1: print("[C] Drift: {}".format(drift))
                frame = q.get()
                input_images.append(frame)

                # Drift is how far behind schedule the current frame is.
                # If more than one entire frame interval behind, skip this frame.
                if drift >= 1 / fps:
                    drift -= 1 / fps
                    input_images.pop(0)
                    if debug >= 1: print("[C] Skipping frame #{}".format(frame.seq))
                else:
                    drift += max(inference_current / batch_size - 1 / fps, 0)
            if len(input_images) >= batch_size:
                tick = time.time()
                if debug == 1: print("[C] Q, I, P size: ({}, {}, {})".format(q.qsize(), len(input_images), output_q.qsize()))
                for i in range(batch_size):
                    minitick = time.time()
                    current_frame = input_images.pop(0)
                    image_np = current_frame.frame
                    # Expand dimensions since the model expects images to have
                    # shape [1, None, None, 3].
                    image_np_expanded = np.expand_dims(image_np, axis=0)
                    # Actual detection.
                    output_dict = run_inference_for_single_image(sess, image_np_expanded, detection_graph)
                    # Visualization of the results of a detection.
                    vis_util.visualize_boxes_and_labels_on_image_array(
                        image_np,
                        output_dict['detection_boxes'],
                        output_dict['detection_classes'],
                        output_dict['detection_scores'],
                        category_index,
                        instance_masks=output_dict.get('detection_masks'),
                        use_normalized_coordinates=True,
                        line_thickness=2)
                    current_frame.frame = image_np
                    minitock = time.time()
                    # Non-blocking put: assumes the player keeps up with the
                    # 20-slot output queue.
                    output_q.put(current_frame, False)

                    # Fake delay so that the batch inference time is > 2 s:
                    # if minitock - minitick < 2 / batch_size:
                    #     time.sleep(2 / batch_size - (time.time() - minitick))
                tock = time.time()
                inference_current = tock - tick
                # Seed the moving average with the first measurement.
                if inference_avg < 0:
                    inference_avg = inference_current
                inference_avg = beta * inference_avg + (1 - beta) * inference_current
                print("[C] Inference current batch: {}".format(inference_current))
                # print("[C] Inference average: {}".format(inference_avg))

    cv.destroyAllWindows()
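
# Worked example of the drift bookkeeping above (illustrative numbers):
# with fps = 5 the per-frame budget is 1 / fps = 0.2 s. If a batch of 8
# frames takes 2.0 s, each received frame adds max(2.0 / 8 - 0.2, 0) = 0.05 s
# of drift; after four frames drift reaches 0.2 s >= 1 / fps, so the next
# frame is dropped and one frame interval is paid back (drift -= 1 / fps).
# Likewise, beta = 1 - 1/4 = 0.75 makes inference_avg an exponential moving
# average weighted toward roughly the last four batch timings.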

# logger = mp.log_to_stderr(logging.DEBUG)

fps = 5
batch_size = 8
debug = 2
# Debug levels:
# 0: print inference time only
# 1: print everything
# 2: print inference time and skipped frames

if __name__ == '__main__':
    consumer_in_q = mp.Queue(20)
    consumer_out_q = mp.Queue(20)

    # Producer process
    p_proc = mp.Process(target=producer, args=(consumer_in_q,))
    # Consumer process
    c_proc = mp.Process(target=consumer, args=(consumer_in_q, consumer_out_q))
    # Video player process
    player = mp.Process(target=video_player, args=(consumer_out_q,))

    p_proc.start()
    c_proc.start()
    player.start()
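
# Assumed setup for running this script: road.mp4 in the working directory,
# the TensorFlow Object Detection API's object_detection package importable,
# the extracted ssd_mobilenet_v1_coco_2018_01_28 directory next to the
# script, and mscoco_label_map.pbtxt under ../data/. The code targets TF 1.x
# APIs (tf.Session, tf.GraphDef, tf.gfile), which are not available at the
# top level in TF 2.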