Advertisement
Guest User

Untitled

a guest
Nov 18th, 2019
136
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.89 KB | None | 0 0
  1. #!/usr/bin/env python
  2. """
  3. Copyright (C) 2018-2019 Intel Corporation
  4.  
  5. Licensed under the Apache License, Version 2.0 (the "License");
  6. you may not use this file except in compliance with the License.
  7. You may obtain a copy of the License at
  8.  
  9.      http://www.apache.org/licenses/LICENSE-2.0
  10.  
  11. Unless required by applicable law or agreed to in writing, software
  12. distributed under the License is distributed on an "AS IS" BASIS,
  13. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. See the License for the specific language governing permissions and
  15. limitations under the License.
  16. """
  17.  
  18. from __future__ import print_function
  19. import sys
  20. import os
  21. from argparse import ArgumentParser, SUPPRESS
  22. import cv2
  23. import time
  24. import logging as log
  25.  
  26. from openvino.inference_engine import IENetwork, IECore
  27.  
  28.  
  29. def build_argparser():
  30.     parser = ArgumentParser(add_help=False)
  31.     args = parser.add_argument_group('Options')
  32.     args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
  33.     args.add_argument("-m", "--model", help="Required. Path to an .xml file with a trained model.",
  34.                       required=True, type=str)
  35.     args.add_argument("-i", "--input",
  36.                       help="Required. Path to video file or image. 'cam' for capturing video stream from camera",
  37.                       required=True, type=str)
  38.     args.add_argument("-l", "--cpu_extension",
  39.                       help="Optional. Required for CPU custom layers. Absolute path to a shared library with the "
  40.                            "kernels implementations.", type=str, default=None)
  41.     args.add_argument("-d", "--device",
  42.                       help="Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL or MYRIAD is "
  43.                            "acceptable. The demo will look for a suitable plugin for device specified. "
  44.                            "Default value is CPU", default="CPU", type=str)
  45.     args.add_argument("--labels", help="Optional. Path to labels mapping file", default=None, type=str)
  46.     args.add_argument("-pt", "--prob_threshold", help="Optional. Probability threshold for detections filtering",
  47.                       default=0.5, type=float)
  48.  
  49.     return parser
  50.  
  51.  
  52. def main():
  53.     log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
  54.     args = build_argparser().parse_args()
  55.     model_xml = args.model
  56.     model_bin = os.path.splitext(model_xml)[0] + ".bin"
  57.  
  58.     log.info("Creating Inference Engine...")
  59.     ie = IECore()
  60.     if args.cpu_extension and 'CPU' in args.device:
  61.         ie.add_extension(args.cpu_extension, "CPU")
  62.     # Read IR
  63.     log.info("Loading network files:\n\t{}\n\t{}".format(model_xml, model_bin))
  64.     net = IENetwork(model=model_xml, weights=model_bin)
  65.  
  66.     if "CPU" in args.device:
  67.         supported_layers = ie.query_network(net, "CPU")
  68.         not_supported_layers = [l for l in net.layers.keys() if l not in supported_layers]
  69.         if len(not_supported_layers) != 0:
  70.             log.error("Following layers are not supported by the plugin for specified device {}:\n {}".
  71.                       format(args.device, ', '.join(not_supported_layers)))
  72.             log.error("Please try to specify cpu extensions library path in sample's command line parameters using -l "
  73.                       "or --cpu_extension command line argument")
  74.             sys.exit(1)
  75.  
  76.     img_info_input_blob = None
  77.     feed_dict = {}
  78.     for blob_name in net.inputs:
  79.         if len(net.inputs[blob_name].shape) == 4:
  80.             input_blob = blob_name
  81.         elif len(net.inputs[blob_name].shape) == 2:
  82.             img_info_input_blob = blob_name
  83.         else:
  84.             raise RuntimeError("Unsupported {}D input layer '{}'. Only 2D and 4D input layers are supported"
  85.                                .format(len(net.inputs[blob_name].shape), blob_name))
  86.  
  87.     assert len(net.outputs) == 1, "Demo supports only single output topologies"
  88.  
  89.     out_blob = next(iter(net.outputs))
  90.     log.info("Loading IR to the plugin...")
  91.     exec_net = ie.load_network(network=net, num_requests=2, device_name=args.device)
  92.     # Read and pre-process input image
  93.     n, c, h, w = net.inputs[input_blob].shape
  94.     if img_info_input_blob:
  95.         feed_dict[img_info_input_blob] = [h, w, 1]
  96.  
  97.     if args.input == 'cam':
  98.         input_stream = 0
  99.     else:
  100.         input_stream = args.input
  101.         assert os.path.isfile(args.input), "Specified input file doesn't exist"
  102.     if args.labels:
  103.         with open(args.labels, 'r') as f:
  104.             labels_map = [x.strip() for x in f]
  105.     else:
  106.         labels_map = None
  107.  
  108.     cap = cv2.VideoCapture(input_stream)
  109.  
  110.     cur_request_id = 0
  111.     next_request_id = 1
  112.  
  113.     log.info("Starting inference in async mode...")
  114.     is_async_mode = True
  115.     render_time = 0
  116.     ret, frame = cap.read()
  117.  
  118.     print("To close the application, press 'CTRL+C' here or switch to the output window and press ESC key")
  119.     print("To switch between sync/async modes, press TAB key in the output window")
  120.  
  121.     while cap.isOpened():
  122.         if is_async_mode:
  123.             ret, next_frame = cap.read()
  124.         else:
  125.             ret, frame = cap.read()
  126.         if not ret:
  127.             break
  128.         initial_w = cap.get(3)
  129.         initial_h = cap.get(4)
  130.         # Main sync point:
  131.         # in the truly Async mode we start the NEXT infer request, while waiting for the CURRENT to complete
  132.         # in the regular mode we start the CURRENT request and immediately wait for it's completion
  133.         inf_start = time.time()
  134.         if is_async_mode:
  135.             in_frame = cv2.resize(next_frame, (w, h))
  136.             in_frame = in_frame.transpose((2, 0, 1))  # Change data layout from HWC to CHW
  137.             in_frame = in_frame.reshape((n, c, h, w))
  138.             feed_dict[input_blob] = in_frame
  139.             exec_net.start_async(request_id=next_request_id, inputs=feed_dict)
  140.         else:
  141.             in_frame = cv2.resize(frame, (w, h))
  142.             in_frame = in_frame.transpose((2, 0, 1))  # Change data layout from HWC to CHW
  143.             in_frame = in_frame.reshape((n, c, h, w))
  144.             feed_dict[input_blob] = in_frame
  145.             exec_net.start_async(request_id=cur_request_id, inputs=feed_dict)
  146.         if exec_net.requests[cur_request_id].wait(-1) == 0:
  147.             inf_end = time.time()
  148.             det_time = inf_end - inf_start
  149.  
  150.             # Parse detection results of the current request
  151.             res = exec_net.requests[cur_request_id].outputs[out_blob]
  152.             for obj in res[0][0]:
  153.                 # Draw only objects when probability more than specified threshold
  154.                 if obj[2] > args.prob_threshold:
  155.                     xmin = int(obj[3] * initial_w)
  156.                     ymin = int(obj[4] * initial_h)
  157.                     xmax = int(obj[5] * initial_w)
  158.                     ymax = int(obj[6] * initial_h)
  159.                     class_id = int(obj[1])
  160.                     # Draw box and label\class_id
  161.                     color = (min(class_id * 12.5, 255), min(class_id * 7, 255), min(class_id * 5, 255))
  162.                     cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), color, 2)
  163.                     det_label = labels_map[class_id] if labels_map else str(class_id)
  164.                     cv2.putText(frame, det_label + ' ' + str(round(obj[2] * 100, 1)) + ' %', (xmin, ymin - 7),
  165.                                 cv2.FONT_HERSHEY_COMPLEX, 0.6, color, 1)
  166.  
  167.             # Draw performance stats
  168.             inf_time_message = "Inference time: N\A for async mode" if is_async_mode else \
  169.                 "Inference time: {:.3f} ms".format(det_time * 1000)
  170.             render_time_message = "OpenCV rendering time: {:.3f} ms".format(render_time * 1000)
  171.             async_mode_message = "Async mode is on. Processing request {}".format(cur_request_id) if is_async_mode else \
  172.                 "Async mode is off. Processing request {}".format(cur_request_id)
  173.  
  174.             cv2.putText(frame, inf_time_message, (15, 15), cv2.FONT_HERSHEY_COMPLEX, 0.5, (200, 10, 10), 1)
  175.             cv2.putText(frame, render_time_message, (15, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (10, 10, 200), 1)
  176.             cv2.putText(frame, async_mode_message, (10, int(initial_h - 20)), cv2.FONT_HERSHEY_COMPLEX, 0.5,
  177.                         (10, 10, 200), 1)
  178.  
  179.         #
  180.         render_start = time.time()
  181.         cv2.imshow("Detection Results", frame)
  182.         render_end = time.time()
  183.         render_time = render_end - render_start
  184.  
  185.         if is_async_mode:
  186.             cur_request_id, next_request_id = next_request_id, cur_request_id
  187.             frame = next_frame
  188.  
  189.         key = cv2.waitKey(1)
  190.         if key == 27:
  191.             break
  192.         if (9 == key):
  193.             is_async_mode = not is_async_mode
  194.             log.info("Switched to {} mode".format("async" if is_async_mode else "sync"))
  195.  
  196.     cv2.destroyAllWindows()
  197.  
  198.  
  199. if __name__ == '__main__':
  200.     sys.exit(main() or 0)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement