Advertisement
Guest User

Untitled

a guest
Jul 29th, 2018
218
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.02 KB | None | 0 0
  1.  
  2.  
  3.     # USAGE
  4.     # import the necessary packages
  5.     from mvnc import mvncapi as mvnc
  6.     from imutils.video import VideoStream
  7.     from imutils.video import FPS
  8.     import argparse
  9.     import numpy as np
  10.     import time
  11.     import cv2
  12.     import traceback
  13.    
  # Build the list of class labels the network was trained to detect from
  # the label map file. (A hard-coded 21-entry VOC tuple used to be assigned
  # to CLASSES here and was immediately overwritten by the list read from
  # disk, so that dead assignment has been removed.)
  label_map_path = 'data/coco/labelmap_coco.prototxt'

  # Every 5th line, starting at index 3, is split on ": " and the right-hand
  # side (with quotes and the trailing newline removed) becomes one label.
  # NOTE(review): presumably this matches a 5-line `item { ... }` layout
  # whose 4th line is the display name -- confirm against the label map
  # that ships with the graph being used.
  with open(label_map_path) as f:
      lines = f.readlines()
  CLASSES = [
      line.split(": ")[1].replace("\"", "").replace("\n", "")
      for line in lines[3::5]
  ]
  print (CLASSES)

  # one random colour per class, used for its boxes and label text
  COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

  # frame dimensions should be square
  PREPROCESS_DIMS = (300, 300)
  DISPLAY_DIMS = (900, 900)

  # calculate the (integer) multiplier needed to scale bounding boxes from
  # the preprocessed frame size up to the display size
  DISP_MULTIPLIER = DISPLAY_DIMS[0] // PREPROCESS_DIMS[0]
  38.    
  def preprocess_image(input_image):
      """Resize and normalise a frame before it is sent to the network.

      The frame is resized to PREPROCESS_DIMS, shifted by -127.5, scaled by
      0.007843 and finally converted to float32.

      Args:
          input_image: image array accepted by ``cv2.resize``.

      Returns:
          The resized, normalised image as a float32 array.
      """
      resized = cv2.resize(input_image, PREPROCESS_DIMS)
      normalised = (resized - 127.5) * 0.007843
      return normalised.astype(np.float32)
  48.    
  def predict(image, graph,input_fifo, output_fifo):
      """Run one frame through the NCS graph and decode the detections.

      The frame is preprocessed, queued on ``input_fifo`` and the result is
      read back from ``output_fifo``. The flat output is decoded as:
      element 0 holds the number of boxes, the first box starts at index 7
      and every box occupies 7 consecutive values, of which offset 1 is the
      class id, offset 2 the confidence and offsets 3-6 the x1, y1, x2, y2
      corner coordinates as fractions of the preprocessed image size.

      Args:
          image: raw frame, passed to ``preprocess_image``.
          graph: graph object providing ``queue_inference_with_fifo_elem``.
          input_fifo: FIFO the preprocessed frame is written to.
          output_fifo: FIFO the inference result is read from.

      Returns:
          A list of ``(class_id, confidence, ((x1, y1), (x2, y2)))`` tuples
          with the corner points clipped to the preprocessed image size.
          Empty when the reported box count is not a finite number.
      """
      # send the preprocessed image to the NCS and run a forward pass to
      # grab the network predictions
      image = preprocess_image(image)
      print("start queue inference")
      graph.queue_inference_with_fifo_elem(input_fifo, output_fifo, image, image)
      print("success queue inference")
      (output, _) = output_fifo.read_elem()
      print("success read_element")

      predictions = []

      # a non-finite box count cannot be converted with int(); treat it as
      # "no detections" instead of raising
      num_valid_boxes = output[0]
      if not np.isfinite(num_valid_boxes):
          return predictions

      # boxes are expressed relative to the preprocessed image, so its size
      # is what the fractional coordinates get scaled and clipped against
      (h, w) = image.shape[:2]

      for box_index in range(int(num_valid_boxes)):
          # slice out the 7 values that describe this box
          base_index = 7 + box_index * 7
          box = output[base_index:base_index + 7]

          # stop if the reported count runs past the end of the output
          if len(box) < 7:
              break

          # boxes with non-finite (inf, nan, etc) numbers must be ignored
          if not np.all(np.isfinite(box)):
              continue

          # clip the box to the image in case the network returns
          # coordinates outside of the image boundaries
          x1 = max(0, int(box[3] * w))
          y1 = max(0, int(box[4] * h))
          x2 = min(w, int(box[5] * w))
          y2 = min(h, int(box[6] * h))

          # grab the prediction class label, confidence (i.e., probability),
          # and bounding box (x, y)-coordinates
          print("pred class",box[1])
          pred_class = int(box[1])
          pred_conf = box[2]
          pred_boxpts = ((x1, y1), (x2, y2))

          # create the prediction tuple and append it to the list
          predictions.append((pred_class, pred_conf, pred_boxpts))

      # return the list of predictions to the calling function
      return predictions
  107.    
  # construct the argument parser and parse the arguments
  ap = argparse.ArgumentParser()
  ap.add_argument("-g", "--graph", default="graphs/vggnetgraph",
      help="path to input graph file")
  # type=float is required: without it a value given on the command line
  # stays a string and cannot be compared with the numeric confidences
  ap.add_argument("-c", "--confidence", type=float, default=.5,
      help="confidence threshold")
  ap.add_argument("-d", "--display", type=int, default=1,
      help="switch to display image on screen")
  args = vars(ap.parse_args())
  117.    
  # grab a list of all NCS devices plugged in to USB
  print("[INFO] finding NCS devices...")
  devices = mvnc.enumerate_devices()

  # if no devices found, exit the script
  if len(devices) == 0:
      print("[INFO] No devices found. Please plug in a NCS")
      quit()

  # use the first device since this is a simple test script
  # (you'll want to modify this if using multiple NCS devices)
  print("[INFO] found {} devices. device0 will be used. "
      "opening device0...".format(len(devices)))
  device = mvnc.Device(devices[0])
  device.open()

  # open the CNN graph file
  print("[INFO] loading the graph file into RPi memory...")
  with open(args["graph"], mode="rb") as f:
      graph_in_memory = f.read()

  # load the graph into the NCS and create its input/output FIFOs in a
  # single step. The graph used to be allocated with graph.allocate() and
  # then a second time through allocate_with_fifos(); only the latter is
  # kept so the graph is allocated exactly once.
  print("[INFO] allocating the graph on the NCS...")
  graph = mvnc.Graph('graph1')
  print("[INFO] creating fifo ")
  input_fifo, output_fifo = graph.allocate_with_fifos(device, graph_in_memory)
  print("[INFO] fifo created")

  # open a pointer to the video stream thread and allow the buffer to
  # start to fill, then start the FPS counter (started last so that device
  # setup time is not counted against the frame rate)
  print("[INFO] starting the video stream and FPS counter...")
  vs = VideoStream(usePiCamera=True).start()
  time.sleep(1)
  fps = FPS().start()
  # loop over frames from the video file stream
  while True:
      try:
          # grab the frame from the threaded video stream
          frame = vs.read()

          # the stream may not have produced a frame yet; wait briefly and
          # retry instead of failing on a None frame
          if frame is None:
              time.sleep(0.01)
              continue

          # flip the frame vertically once and use the flipped frame for
          # both detection and display. Previously only the display copy
          # was flipped, so boxes computed on the unflipped frame were
          # drawn vertically mirrored relative to the objects.
          frame = cv2.flip(frame, 0)

          # resized copy of the frame used for display purposes
          image_for_result = cv2.resize(frame, DISPLAY_DIMS)

          # use the NCS to acquire predictions
          print("[INFO] start predict")
          predictions = predict(frame, graph,input_fifo,output_fifo)
          print("[INFO] success predict")

          # loop over our predictions
          for (i, pred) in enumerate(predictions):
              # extract prediction data for readability
              (pred_class, pred_conf, pred_boxpts) = pred

              # filter out weak detections by ensuring the `confidence`
              # is greater than the minimum confidence
              if pred_conf <= args["confidence"]:
                  continue

              # print prediction to terminal
              print("[INFO] Prediction #{}: class={},classID={}, confidence={}, "
                  "boxpoints={}".format(i, CLASSES[pred_class],pred_class, pred_conf,
                  pred_boxpts))

              # check if we should show the prediction data on the frame
              if args["display"] > 0:
                  # build a label consisting of the predicted class and
                  # associated probability
                  label = "{}: {:.2f}%".format(CLASSES[pred_class],
                      pred_conf * 100)

                  # scale the box corners from the preprocessed size up to
                  # the display size
                  (ptA, ptB) = (pred_boxpts[0], pred_boxpts[1])
                  ptA = (ptA[0] * DISP_MULTIPLIER, ptA[1] * DISP_MULTIPLIER)
                  ptB = (ptB[0] * DISP_MULTIPLIER, ptB[1] * DISP_MULTIPLIER)
                  (startX, startY) = (ptA[0], ptA[1])

                  # put the label above the box unless that would place
                  # it too close to the top edge
                  y = startY - 15 if startY - 15 > 15 else startY + 15

                  # display the rectangle and label text
                  cv2.rectangle(image_for_result, ptA, ptB,
                      COLORS[pred_class], 2)
                  cv2.putText(image_for_result, label, (startX, y),
                      cv2.FONT_HERSHEY_SIMPLEX, 1, COLORS[pred_class], 3)

          # check if we should display the frame on the screen
          # with prediction data (you can achieve faster FPS if you
          # do not output to the screen)
          if args["display"] > 0:
              # display the frame to the screen
              cv2.imshow("Output", image_for_result)
              key = cv2.waitKey(1) & 0xFF

              # if the `q` key was pressed, break from the loop
              if key == ord("q"):
                  break

          # update the FPS counter
          fps.update()

      # if "ctrl+c" is pressed in the terminal, break from the loop
      except KeyboardInterrupt:
          break

      # any other error is reported and the loop carries on with the next
      # frame (best effort: one bad frame does not stop the stream)
      except Exception as e:
          print(e)
          traceback.print_exc()
  224.    
  # stop the FPS counter timer
  fps.stop()

  # destroy all windows if we are displaying them
  if args["display"] > 0:
      cv2.destroyAllWindows()

  # stop the video stream
  vs.stop()

  # clean up the FIFOs, the graph and the device. The FIFOs created by
  # allocate_with_fifos() were previously never released, and the device
  # handle was closed but not destroyed.
  input_fifo.destroy()
  output_fifo.destroy()
  graph.destroy()
  device.close()
  device.destroy()

  # display FPS information
  print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
  print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement