Advertisement
Guest User

Untitled

a guest
Sep 15th, 2018
132
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.85 KB | None | 0 0
  1.  
  2. # USAGE
  3. # python ncs_realtime_objectdetection.py --graph graphs/mobilenetgraph --display 1
  4. # python ncs_realtime_objectdetection.py --graph graphs/mobilenetgraph --confidence 0.5 --display 1
  5.  
  6. # import the necessary packages
  7. from mvnc import mvncapi as mvnc
  8. from imutils.video import VideoStream
  9. from imutils.video import FPS
  10. import argparse
  11. import numpy as np
  12. import time
  13. import cv2
  14. import traceback
  15.  
  16. # initialize the list of class labels our network was trained to
  17. # detect, then generate a set of bounding box colors for each class
  17a. # (21-class PASCAL VOC ordering used by MobileNet-SSD; index 0 is background)
  18. CLASSES = ("background", "aeroplane", "bicycle", "bird",
  19.     "boat", "bottle", "bus", "car", "cat", "chair", "cow",
  20.     "diningtable", "dog", "horse", "motorbike", "person",
  21.     "pottedplant", "sheep", "sofa", "train", "tvmonitor")
  21a. # one random color per class for drawing boxes/labels
  22. COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
  22a. # class ids reported by the main loop: 0 = "background", 15 = "person"
  22b. # NOTE(review): including 0 (background) here looks unintentional -- confirm
  23. TRACKING_CLASS = (0, 15)
  24. # frame dimensions should be square (the network's input size)
  25. PREPROCESS_DIMS = (300, 300)
  25a. # size of the window shown to the user
  26. DISPLAY_DIMS = (900, 900)
  26a. # default minimum confidence used to filter weak detections
  27. DEFAULT_CONF_TRES = .2
  28. # calculate the multiplier needed to scale the bounding boxes
  28a. # from network-input space (300px) up to display space (900px)
  29. DISP_MULTIPLIER = DISPLAY_DIMS[0] // PREPROCESS_DIMS[0]
  30.  
  31. def preprocess_image(input_image):
  32.     # preprocess the image
  33.     preprocessed = cv2.resize(input_image, PREPROCESS_DIMS)
  34.     preprocessed = preprocessed - 127.5
  35.     preprocessed = preprocessed * 0.007843
  36.     preprocessed = preprocessed.astype(np.float32)
  37.  
  38.     # return the image to the calling function
  39.     return preprocessed
  40.  
  41. def predict(image, graph,input_fifo, output_fifo):
  42.     # preprocess the image
  43.     #image = preprocess_image(image)
  44.  
  45.     # send the image to the NCS and run a forward pass to grab the
  46.     # network predictions
  47.     image = preprocess_image(image)
  48.     graph.queue_inference_with_fifo_elem (input_fifo, output_fifo, image, image)
  49.    
  50.     (output, _) = output_fifo.read_elem()
  51.     # grab the number of valid object predictions from the output,
  52.     # then initialize the list of predictions
  53.     num_valid_boxes = output[0]
  54.     predictions = []
  55.     print("num_valid_boxes",num_valid_boxes)
  56.     # loop over results
  57.     for box_index in range(int(num_valid_boxes)):
  58.         # calculate the base index into our array so we can extract
  59.         # bounding box information
  60.         base_index = 7 + box_index * 7
  61.  
  62.         # boxes with non-finite (inf, nan, etc) numbers must be ignored
  63.         if (not np.isfinite(output[base_index]) or
  64.             not np.isfinite(output[base_index + 1]) or
  65.             not np.isfinite(output[base_index + 2]) or
  66.             not np.isfinite(output[base_index + 3]) or
  67.             not np.isfinite(output[base_index + 4]) or
  68.             not np.isfinite(output[base_index + 5]) or
  69.             not np.isfinite(output[base_index + 6])):
  70.             continue
  71.  
  72.         # extract the image width and height and clip the boxes to the
  73.         # image size in case network returns boxes outside of the image
  74.         # boundaries
  75.         (h, w) = image.shape[:2]
  76.         x1 = max(0, int(output[base_index + 3] * w))
  77.         y1 = max(0, int(output[base_index + 4] * h))
  78.         x2 = min(w, int(output[base_index + 5] * w))
  79.         y2 = min(h, int(output[base_index + 6] * h))
  80.  
  81.         # grab the prediction class label, confidence (i.e., probability),
  82.         # and bounding box (x, y)-coordinates
  83.         pred_class = int(output[base_index + 1])
  84.         pred_conf = output[base_index + 2]
  85.         pred_boxpts = ((x1, y1), (x2, y2))
  86.  
  87.         # create prediciton tuple and append the prediction to the
  88.         # predictions list
  89.         prediction = (pred_class, pred_conf, pred_boxpts)
  90.         predictions.append(prediction)
  91.  
  92.     # return the list of predictions to the calling function
  93.     return predictions
  94.  
  95. # construct the argument parser and parse the arguments
  96. ap = argparse.ArgumentParser()
  97. ap.add_argument("-g", "--graph", default="graphs/mobilenetgraph",
  98.     help="path to input graph file")
  99. ap.add_argument("-c", "--confidence", default=DEFAULT_CONF_TRES,
  100.     help="confidence threshold")
  101. ap.add_argument("-d", "--display", type=int, default=1,
  102.     help="switch to display image on screen")
  103. args = vars(ap.parse_args())
  104.  
  105. # grab a list of all NCS devices plugged in to USB
  106. print("[INFO] finding NCS devices...")
  107. devices = mvnc.enumerate_devices()
  108.  
  109. # if no devices found, exit the script
  110. if len(devices) == 0:
  111.     print("[INFO] No devices found. Please plug in a NCS")
  112.     quit()
  113.  
  114. # use the first device since this is a simple test script
  115. # (you'll want to modify this is using multiple NCS devices)
  116. print("[INFO] found {} devices. device0 will be used. "
  117.     "opening device0...".format(len(devices)))
  118. device = mvnc.Device(devices[0])
  119. device.open()
  120.  
  121. # open the CNN graph file
  122. print("[INFO] loading the graph file into RPi memory...")
  123. with open(args["graph"], mode="rb") as f:
  124.     graph_in_memory = f.read()
  125.  
  126. # load the graph into the NCS
  127. print("[INFO] allocating the graph on the NCS...")
  128. graph = mvnc.Graph('graph1')
  129. graph.allocate(device,graph_in_memory)
  130.  
  131. # open a pointer to the video stream thread and allow the buffer to
  132. # start to fill, then start the FPS counter
  133. print("[INFO] starting the video stream and FPS counter...")
  134. vs = VideoStream(usePiCamera=True).start()
  135. time.sleep(1)
  136. fps = FPS().start()
  137. print("[INFO] creating fifo ")
  138. input_fifo, output_fifo = graph.allocate_with_fifos(device, graph_in_memory)
  139. print("[INFO] fifo created")
  140. # loop over frames from the video file stream
  141. while True:
  142.     try:
  143.         # grab the frame from the threaded video stream
  144.         # make a copy of the frame and resize it for display/video purposes
  145.         frame = vs.read()
  146.         image_for_result = frame.copy()
  147.         image_for_result = cv2.resize(image_for_result, DISPLAY_DIMS)
  148.         image_for_result = cv2.flip(image_for_result,0)
  149.         # use the NCS to acquire predictions
  150.         predictions = predict(frame, graph,input_fifo,output_fifo)
  151.  
  152.         # loop over our predictions
  153.         for (i, pred) in enumerate(predictions):
  154.             # extract prediction data for readability
  155.             (pred_class, pred_conf, pred_boxpts) = pred
  156.  
  157.             # filter out weak detections by ensuring the `confidence`
  158.             # is greater than the minimum confidence
  159.             if  ( pred_class in TRACKING_CLASS) and pred_conf > args["confidence"]:
  160.                 # print prediction to terminal
  161.                 print("[INFO] Prediction #{}: class={}, confidence={}, "
  162.                     "boxpoints={}".format(i, CLASSES[pred_class], pred_conf,
  163.                     pred_boxpts))
  164.  
  165.                 # check if we should show the prediction data
  166.                 # on the frame
  167.                 if args["display"] > 0:
  168.                     # build a label consisting of the predicted class and
  169.                     # associated probability
  170.                     label = "{}: {:.2f}%".format(CLASSES[pred_class],
  171.                         pred_conf * 100)
  172.  
  173.                     # extract information from the prediction boxpoints
  174.                     (ptA, ptB) = (pred_boxpts[0], pred_boxpts[1])
  175.                     ptA = (ptA[0] * DISP_MULTIPLIER, ptA[1] * DISP_MULTIPLIER)
  176.                     ptB = (ptB[0] * DISP_MULTIPLIER, ptB[1] * DISP_MULTIPLIER)
  177.                     (startX, startY) = (ptA[0], ptA[1])
  178.                     y = startY - 15 if startY - 15 > 15 else startY + 15
  179.  
  180.                     # display the rectangle and label text
  181.                     cv2.rectangle(image_for_result, ptA, ptB,
  182.                         COLORS[pred_class], 2)
  183.                     cv2.putText(image_for_result, label, (startX, y),
  184.                         cv2.FONT_HERSHEY_SIMPLEX, 1, COLORS[pred_class], 3)
  185.  
  186.         # check if we should display the frame on the screen
  187.         # with prediction data (you can achieve faster FPS if you
  188.         # do not output to the screen)
  189.         if args["display"] > 0:
  190.             # display the frame to the screen
  191.             cv2.imshow("Output", image_for_result)
  192.             key = cv2.waitKey(1) & 0xFF
  193.  
  194.             # if the `q` key was pressed, break from the loop
  195.             if key == ord("q"):
  196.                 break
  197.  
  198.         # update the FPS counter
  199.         fps.update()
  200.    
  201.     # if "ctrl+c" is pressed in the terminal, break from the loop
  202.     except KeyboardInterrupt:
  203.         break
  204.  
  205.     # if there's a problem reading a frame, break gracefully
  206.     except Exception as e:
  207.         #break
  208.         print(e)
  209.         traceback.print_exc()
  210.  
  211. # stop the FPS counter timer
  212. fps.stop()
  213.  
  214. # destroy all windows if we are displaying them
  215. if args["display"] > 0:
  216.     cv2.destroyAllWindows()
  217.  
  218. # stop the video stream
  219. vs.stop()
  220.  
  221. # clean up the graph and device
  222. graph.destroy()
  223. device.close()
  224.  
  225. # display FPS information
  226. print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
  227. print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement