#!/usr/bin/env python
# coding: utf-8

# # Misty Object Detection Websocket server
# This script needs to be run from the tensorflow object_detection library folder.
# For information on setting up the object_detection library visit
# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md

# # Imports
import collections
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
import base64
import re
from io import BytesIO
import asyncio
import websockets
import json

from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

# This is needed since the script is run from inside the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops

if StrictVersion(tf.__version__) < StrictVersion('1.12.0'):
  raise ImportError('Please upgrade your TensorFlow installation to v1.12.*.')


# ## Env setup

# ## Object detection imports
# Here are the imports from the object detection module.

from utils import label_map_util

from utils import visualization_utils as vis_util


# # Model preparation

# ## Variables
#
# Any model exported using the `export_inference_graph.py` tool can be loaded here simply by changing `PATH_TO_FROZEN_GRAPH` to point to a new .pb file.
#
# By default we use an "SSD with Mobilenet" model here. See the [detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md) for a list of other models that can be run out-of-the-box with varying speeds and accuracies.

# What model to download.
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that are used to add the correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

# ## Download Model
opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
  file_name = os.path.basename(file.name)
  if 'frozen_inference_graph.pb' in file_name:
    tar_file.extract(file, os.getcwd())

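# Note: the block above re-downloads and re-extracts the model archive on every run.
# A minimal optional guard to skip the download when the frozen graph is already on
# disk could look like the following sketch:
#
#   if not os.path.exists(PATH_TO_FROZEN_GRAPH):
#       opener = urllib.request.URLopener()
#       opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
#       ...extract as above...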

# ## Load a (frozen) Tensorflow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')


# ## Loading label map
# Label maps map indices to category names, so that when our convolutional network predicts `5`, we know that this corresponds to `airplane`. Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine.
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)

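# For the COCO label map used here, category_index maps class ids to display info, e.g.:
#   {1: {'id': 1, 'name': 'person'}, 2: {'id': 2, 'name': 'bicycle'}, ...}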

# ## Helper code
def load_image_into_numpy_array(image):
  # Convert a PIL image into a (height, width, 3) uint8 numpy array.
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)


# # Detection

# use this when loading saved images
PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]

# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)


def run_inference_for_single_image(image, graph):
  with graph.as_default():
    with tf.Session() as sess:
      # Get handles to input and output tensors
      ops = tf.get_default_graph().get_operations()
      all_tensor_names = {output.name for op in ops for output in op.outputs}
      tensor_dict = {}
      for key in [
          'num_detections', 'detection_boxes', 'detection_scores',
          'detection_classes', 'detection_masks'
      ]:
        tensor_name = key + ':0'
        if tensor_name in all_tensor_names:
          tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
              tensor_name)
      if 'detection_masks' in tensor_dict:
        # The following processing is only for single image
        detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
        detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
        # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
        real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
        detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
        detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
        detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
            detection_masks, detection_boxes, image.shape[1], image.shape[2])
        detection_masks_reframed = tf.cast(
            tf.greater(detection_masks_reframed, 0.5), tf.uint8)
        # Follow the convention by adding back the batch dimension
        tensor_dict['detection_masks'] = tf.expand_dims(
            detection_masks_reframed, 0)
      image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

      # Run inference
      output_dict = sess.run(tensor_dict,
                             feed_dict={image_tensor: image})

      # all outputs are float32 numpy arrays, so convert types as appropriate
      output_dict['num_detections'] = int(output_dict['num_detections'][0])
      output_dict['detection_classes'] = output_dict[
          'detection_classes'][0].astype(np.int64)
      output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
      output_dict['detection_scores'] = output_dict['detection_scores'][0]
      if 'detection_masks' in output_dict:
        output_dict['detection_masks'] = output_dict['detection_masks'][0]
  return output_dict

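# For reference, the output_dict returned above follows the standard TensorFlow
# Object Detection API layout for a single image:
#   output_dict['num_detections']    -> int, number of valid detections
#   output_dict['detection_boxes']   -> float32 array of shape (N, 4); each row is
#                                       [ymin, xmin, ymax, xmax] in normalized 0..1 coordinates
#   output_dict['detection_scores']  -> float32 array of shape (N,), sorted highest first
#   output_dict['detection_classes'] -> integer array of shape (N,), ids into category_index
#   output_dict['detection_masks']   -> only present for models that also output instance masks
# (N is the model's fixed maximum number of detections, e.g. 100 for the SSD model above.)
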
def pure_pil_alpha_to_color_v2(image, color=(255, 255, 255)):
    """Alpha composite an RGBA Image with a specified color.

    Source: http://stackoverflow.com/a/9459208/284318

    Keyword Arguments:
    image -- PIL RGBA Image object
    color -- Tuple r, g, b (default 255, 255, 255)

    """
    image.load()  # needed for split()
    background = Image.new('RGB', image.size, color)
    background.paste(image, mask=image.split()[3])  # 3 is the alpha channel
    return background

def get_bounding_box_depth_data(output_dict, image_depth_data, image_width, image_height):
  # matches the detection threshold used
  # in visualize_boxes_and_labels_on_image_array
  detection_score_threshold = 0.5
  # each bounding_box_depth_data value will be the corresponding
  # depth data for a given detection box
  bounding_box_depth_data = []
  # each object detection box is an array of normalized (0 to 1) values in the form
  # (top-left-y, top-left-x, bottom-right-y, bottom-right-x)
  # e.g. top-left-x = 0.5, image_width = 300px => pixel value is 150px
  detection_boxes = output_dict['detection_boxes']
  detection_scores = output_dict['detection_scores']
  for i in range(0, len(detection_boxes)):
    # detections are sorted by score, so stop at the first one below the threshold
    if detection_scores[i] <= detection_score_threshold:
      break
    bounding_box_depth_data.append([])
    detection_box = detection_boxes[i]
    top_left_y = int(detection_box[0] * image_height)
    top_left_x = int(detection_box[1] * image_width)
    bottom_right_y = int(detection_box[2] * image_height)
    bottom_right_x = int(detection_box[3] * image_width)
    for y in range(top_left_y, bottom_right_y):
      for x in range(top_left_x, bottom_right_x):
        # row-major index into the flattened depth array
        depth_value_index = x + (y * image_width)
        if depth_value_index >= 0 and depth_value_index < len(image_depth_data):
          # these dicts are formatted to match the object schema used
          # to render the d3.js heatmap in Logging/src/views/takeDepthPicture.vue
          value = image_depth_data[depth_value_index]
          bounding_box_depth_data[i].append({
            'x': x,
            'y': y,
            'value': value
          })

  return bounding_box_depth_data

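# Illustrative example of the structure returned by get_bounding_box_depth_data
# (one list of point dicts per accepted detection box; the values are made up):
#   [
#     [{'x': 120, 'y': 80, 'value': 1.30}, {'x': 121, 'y': 80, 'value': 1.28}, ...],  # box 0
#     [{'x': 40, 'y': 150, 'value': 2.10}, ...],                                      # box 1
#   ]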

def get_bounding_box_centroid_depth_data(output_dict, image_depth_data, image_width, image_height):
  detection_score_threshold = 0.5
  bounding_box_centroid_depth_data = []
  detection_boxes = output_dict['detection_boxes']
  detection_scores = output_dict['detection_scores']
  for i in range(0, len(detection_boxes)):
    if detection_scores[i] <= detection_score_threshold:
      break
    centroid_depth = "NaN"
    detection_box = detection_boxes[i]
    top_left_y = int(detection_box[0] * image_height)
    top_left_x = int(detection_box[1] * image_width)
    bottom_right_y = int(detection_box[2] * image_height)
    bottom_right_x = int(detection_box[3] * image_width)
    offset_direction = 0
    offset_magnitude = -1
    while centroid_depth == "NaN":
      # start from the center of the box and step outwards
      # (top, right, bottom, left) with a growing offset until a usable value is found
      centroid_x = top_left_x + (bottom_right_x - top_left_x) // 2
      centroid_y = top_left_y + (bottom_right_y - top_left_y) // 2
      offset_direction += 1
      offset_magnitude += 1
      if offset_direction > 4:
        offset_direction = 1
      if offset_direction == 1:
        # top
        centroid_y += offset_magnitude
      if offset_direction == 2:
        # right
        centroid_x += offset_magnitude
      if offset_direction == 3:
        # bottom
        centroid_y -= offset_magnitude
      if offset_direction == 4:
        # left
        centroid_x -= offset_magnitude
      # row-major index into the flattened depth array
      centroid_index = centroid_x + (centroid_y * image_width)
      if centroid_index >= 0 and centroid_index < len(image_depth_data):
        centroid_depth = image_depth_data[centroid_index]
      else:
        # the centroid is outside the depth data
        centroid_depth = -1
    bounding_box_centroid_depth_data.append(centroid_depth)

  return bounding_box_centroid_depth_data

def get_turn_direction(detection_box, image_width):
  turn_direction = "none"
  threshold_offset = 10
  turn_threshold_min = int(image_width / 2) - threshold_offset
  turn_threshold_max = int(image_width / 2) + threshold_offset

  top_left_x = int(detection_box[1] * image_width)
  bottom_right_x = int(detection_box[3] * image_width)
  # horizontal center of the detection box, in pixels
  x_center = top_left_x + (bottom_right_x - top_left_x) // 2
  if x_center < turn_threshold_min:
    turn_direction = "left"
  if x_center > turn_threshold_max:
    turn_direction = "right"

  return turn_direction

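# Worked example for get_turn_direction (illustrative numbers): with image_width = 320
# the turn thresholds are 160 +/- 10. For a detection box [0.1, 0.2, 0.6, 0.9]
# ([ymin, xmin, ymax, xmax]), the horizontal center is 64 + (288 - 64) // 2 = 176 px,
# which is above the 170 px threshold, so the function returns "right".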

# parses a base64 image into a PIL image,
# then runs object detection on that image.
# if it is a fisheye image, derives depth data
# using the object detection boxes,
# and returns the processed image as well as
# any relevant depth or bounding data
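# The incoming websocket message is expected to be a JSON object containing the
# fields read below (field names come from the parsing code; the values shown are
# only illustrative):
#   {
#     "image": "<base64-encoded image bytes>",
#     "image_height": 480,
#     "image_width": 640,
#     "image_depth_data": [1.30, 1.28, ...],   # flattened depth values, or null
#     "is_fisheye_image": true,
#     "get_depth_data": false
#   }
# The reply sent at the end of process_image is a JSON object with the keys
# 'processed_image' (base64 JPEG with boxes drawn on it), 'image_depth_data',
# 'image_centroid_depth_data' and 'turn_direction'.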
async def process_image(websocket, message):
  print("processing image . . . ")
  parsed_message = json.loads(message)
  image_base_64 = parsed_message['image']
  image_height = parsed_message['image_height']
  image_width = parsed_message['image_width']
  image_depth_data = parsed_message['image_depth_data']
  is_fisheye_image = parsed_message['is_fisheye_image']
  get_depth_data = parsed_message['get_depth_data']

  img = Image.open(BytesIO(base64.b64decode(image_base_64)))
  if is_fisheye_image:
    # remove the alpha channel; fisheye images are RGBA
    # and the detection model expects 3-channel RGB input
    img = pure_pil_alpha_to_color_v2(img)

  # the array based representation of the image will be used later in order to prepare the
  # result image with boxes and labels on it.
  image_np = load_image_into_numpy_array(img)
  # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
  image_np_expanded = np.expand_dims(image_np, axis=0)
  # Actual detection.
  output_dict = run_inference_for_single_image(image_np_expanded, detection_graph)
  # Visualization of the results of a detection.
  vis_util.visualize_boxes_and_labels_on_image_array(
      image_np,
      output_dict['detection_boxes'],
      output_dict['detection_classes'],
      output_dict['detection_scores'],
      category_index,
      instance_masks=output_dict.get('detection_masks'),
      use_normalized_coordinates=True,
      line_thickness=8)
  pil_img = Image.fromarray(image_np)
  buff = BytesIO()
  pil_img.save(buff, format="JPEG")
  processed_image_base_64 = base64.b64encode(buff.getvalue()).decode("utf-8")

  has_detections = len(output_dict['detection_boxes']) > 0
  # depth and bounding data are only relevant to fisheye images;
  # using depth data with non-fisheye images gives poor results
  # because the fisheye lens distorts the proportions of the image
  bounding_box_depth_data = None
  bounding_box_centroid_depth_data = None
  turn_direction = None
  if is_fisheye_image and has_detections:
    if get_depth_data:
      # retrieve the associated depth data for each bounding box
      # useful for debugging
      bounding_box_depth_data = get_bounding_box_depth_data(output_dict, image_depth_data, image_width, image_height)
    # get the depth value of the centroid of each bounding box
    # if the exact centroid is 'NaN' search in each direction and
    # return the first non-NaN value. If the centroid is outside the
    # depth data bounds, return -1
    bounding_box_centroid_depth_data = get_bounding_box_centroid_depth_data(output_dict, image_depth_data, image_width, image_height)
    # get a turn direction based off of the nearest detection box
    # (the box whose centroid has the smallest valid depth value)
    nearest_detection_index = 0
    nearest_detection_value = None
    for i in range(0, len(bounding_box_centroid_depth_data)):
      centroid_depth = bounding_box_centroid_depth_data[i]
      # skip boxes whose centroid fell outside the depth data (-1)
      if centroid_depth <= 0:
        continue
      if nearest_detection_value is None or centroid_depth < nearest_detection_value:
        nearest_detection_index = i
        nearest_detection_value = centroid_depth
    turn_direction = get_turn_direction(output_dict['detection_boxes'][nearest_detection_index], image_width)

  return_data = {}
  return_data['image_depth_data'] = bounding_box_depth_data
  return_data['image_centroid_depth_data'] = bounding_box_centroid_depth_data
  return_data['processed_image'] = processed_image_base_64
  return_data['turn_direction'] = turn_direction
  print("image successfully processed")
  await websocket.send(json.dumps(return_data))

# Websocket logic
# https://websockets.readthedocs.io/en/stable/intro.html

async def consumer_handler(websocket, path):
  async for message in websocket:
    print('received message')
    await process_image(websocket, message)

start_server = websockets.serve(consumer_handler, "localhost", 8765)
print('Websocket listening on port 8765!')

asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
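
# A minimal client sketch for testing (for reference only; not executed here).
# It assumes a local file "test.jpg" and uses the same field names process_image reads:
#
#   import asyncio, base64, json, websockets
#
#   async def send_test_image():
#       with open("test.jpg", "rb") as f:
#           image_base_64 = base64.b64encode(f.read()).decode("utf-8")
#       message = json.dumps({
#           "image": image_base_64,
#           "image_height": 480,
#           "image_width": 640,
#           "image_depth_data": None,
#           "is_fisheye_image": False,
#           "get_depth_data": False,
#       })
#       async with websockets.connect("ws://localhost:8765") as ws:
#           await ws.send(message)
#           response = json.loads(await ws.recv())
#           print(response["turn_direction"])
#
#   asyncio.get_event_loop().run_until_complete(send_test_image())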