RPI-Movidius Custom SSD Mobilenet test

# USAGE
# python ncs_realtime_objectdetection.py --graph graphs/mobilenetgraph --display 1
# python ncs_realtime_objectdetection.py --graph graphs/mobilenetgraph --confidence 0.5 --display 1

# import the necessary packages
from mvnc import mvncapi as mvnc
from imutils.video import VideoStream
from imutils.video import FPS
import argparse
import numpy as np
import time
import cv2

# initialize the list of class labels our network was trained to
# detect, then generate a set of bounding box colors for each class
CLASSES = ("background", "can", "bottle")
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

# frame dimensions should be square
PREPROCESS_DIMS = (300, 300)
DISPLAY_DIMS = (900, 900)

# calculate the multiplier needed to scale the bounding boxes
DISP_MULTIPLIER = DISPLAY_DIMS[0] // PREPROCESS_DIMS[0]

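# note (added for clarity): with the dimensions above this multiplier is
# 900 // 300 = 3, so box coordinates computed on the 300x300 network input
# are simply scaled by 3 when drawn on the 900x900 display frame
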
def preprocess_image(input_image):
    # preprocess the image
    preprocessed = cv2.resize(input_image, PREPROCESS_DIMS)
    preprocessed = preprocessed - 127.5
    preprocessed = preprocessed * 0.007843
    preprocessed = preprocessed.astype(np.float16)

    # return the image to the calling function
    return preprocessed

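# note (added for clarity): 0.007843 is approximately 1 / 127.5, so the two
# steps above shift and scale pixel values from [0, 255] to roughly [-1, 1],
# the normalization this MobileNet-SSD graph was trained with; the float16
# cast matches the data type the NCS expects for input tensors
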
def predict(image, graph):
    # preprocess the image
    image = preprocess_image(image)

    # send the image to the NCS and run a forward pass to grab the
    # network predictions
    graph.LoadTensor(image, None)
    (output, _) = graph.GetResult()

    # grab the number of valid object predictions from the output
    # (cast to an int so it can be used with range), then initialize
    # the list of predictions
    num_valid_boxes = int(output[0])
    predictions = []

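    # note (added for clarity): the SSD output tensor is flat -- output[0]
    # holds the number of valid detections and each detection occupies seven
    # values starting at index 7, which the indexing below reads as
    # [image_id, class label, confidence, x_min, y_min, x_max, y_max]
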
    # loop over results
    for box_index in range(num_valid_boxes):
        # calculate the base index into our array so we can extract
        # bounding box information
        base_index = 7 + box_index * 7

        # boxes with non-finite (inf, nan, etc) numbers must be ignored
        if (not np.isfinite(output[base_index]) or
                not np.isfinite(output[base_index + 1]) or
                not np.isfinite(output[base_index + 2]) or
                not np.isfinite(output[base_index + 3]) or
                not np.isfinite(output[base_index + 4]) or
                not np.isfinite(output[base_index + 5]) or
                not np.isfinite(output[base_index + 6])):
            continue

        # extract the image width and height and clip the boxes to the
        # image size in case the network returns boxes outside of the
        # image boundaries
        (h, w) = image.shape[:2]
        x1 = max(0, int(output[base_index + 3] * w))
        y1 = max(0, int(output[base_index + 4] * h))
        x2 = min(w, int(output[base_index + 5] * w))
        y2 = min(h, int(output[base_index + 6] * h))

        # grab the prediction class label, confidence (i.e., probability),
        # and bounding box (x, y)-coordinates
        pred_class = int(output[base_index + 1])
        pred_conf = output[base_index + 2]
        pred_boxpts = ((x1, y1), (x2, y2))

        # create prediction tuple and append the prediction to the
        # predictions list
        prediction = (pred_class, pred_conf, pred_boxpts)
        predictions.append(prediction)

    # return the list of predictions to the calling function
    return predictions

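# example (added for illustration, not part of the original script): once a
# graph has been allocated as below, predict() can be tested on a single
# still image; the file name here is hypothetical
#   image = cv2.imread("test.jpg")
#   for (pred_class, pred_conf, pred_boxpts) in predict(image, graph):
#       print(CLASSES[pred_class], pred_conf, pred_boxpts)
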
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-g", "--graph", required=True,
    help="path to input graph file")
ap.add_argument("-c", "--confidence", type=float, default=.5,
    help="confidence threshold")
ap.add_argument("-d", "--display", type=int, default=0,
    help="switch to display image on screen")
args = vars(ap.parse_args())

# grab a list of all NCS devices plugged in to USB
print("[INFO] finding NCS devices...")
devices = mvnc.EnumerateDevices()

# if no devices are found, exit the script
if len(devices) == 0:
    print("[INFO] No devices found. Please plug in an NCS.")
    quit()

# use the first device since this is a simple test script
# (you'll want to modify this if using multiple NCS devices)
print("[INFO] found {} devices. device0 will be used. "
    "opening device0...".format(len(devices)))
device = mvnc.Device(devices[0])
device.OpenDevice()

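# sketch (added for illustration): if more than one NCS is plugged in, each
# device could be opened the same way and given its own copy of the graph,
# e.g. something along the lines of
#   devices = [mvnc.Device(d) for d in mvnc.EnumerateDevices()]
#   for d in devices:
#       d.OpenDevice()
# the single-device path above is kept as-is for this test script
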
# open the CNN graph file
print("[INFO] loading the graph file into RPi memory...")
with open(args["graph"], mode="rb") as f:
    graph_in_memory = f.read()

# load the graph into the NCS
print("[INFO] allocating the graph on the NCS...")
graph = device.AllocateGraph(graph_in_memory)

# open a pointer to the video stream thread and allow the buffer to
# start to fill, then start the FPS counter
print("[INFO] starting the video stream and FPS counter...")
vs = VideoStream(usePiCamera=True).start()
time.sleep(1)
fps = FPS().start()

# loop over frames from the video stream
while True:
    try:
        # grab the frame from the threaded video stream, then make a
        # copy of the frame and resize it for display/video purposes
        frame = vs.read()
        image_for_result = frame.copy()
        image_for_result = cv2.resize(image_for_result, DISPLAY_DIMS)

        # use the NCS to acquire predictions
        predictions = predict(frame, graph)

        # loop over our predictions
        for (i, pred) in enumerate(predictions):
            # extract prediction data for readability
            (pred_class, pred_conf, pred_boxpts) = pred

            # filter out weak detections by ensuring the `confidence`
            # is greater than the minimum confidence
            if pred_conf > args["confidence"]:
                # print prediction to terminal
                print("[INFO] Prediction #{}: class={}, confidence={}, "
                    "boxpoints={}".format(i, CLASSES[pred_class], pred_conf,
                    pred_boxpts))

                # check if we should show the prediction data
                # on the frame
                if args["display"] > 0:
                    # build a label consisting of the predicted class and
                    # associated probability
                    label = "{}: {:.2f}%".format(CLASSES[pred_class],
                        pred_conf * 100)

                    # extract information from the prediction boxpoints
                    (ptA, ptB) = (pred_boxpts[0], pred_boxpts[1])
                    ptA = (ptA[0] * DISP_MULTIPLIER, ptA[1] * DISP_MULTIPLIER)
                    ptB = (ptB[0] * DISP_MULTIPLIER, ptB[1] * DISP_MULTIPLIER)
                    (startX, startY) = (ptA[0], ptA[1])
                    y = startY - 15 if startY - 15 > 15 else startY + 15

                    # display the rectangle and label text
                    cv2.rectangle(image_for_result, ptA, ptB,
                        COLORS[pred_class], 2)
                    cv2.putText(image_for_result, label, (startX, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, COLORS[pred_class], 3)

        # check if we should display the frame on the screen
        # with prediction data (you can achieve faster FPS if you
        # do not output to the screen)
        if args["display"] > 0:
            # display the frame to the screen
            cv2.imshow("Output", image_for_result)
            key = cv2.waitKey(1) & 0xFF

            # if the `q` key was pressed, break from the loop
            if key == ord("q"):
                break

        # update the FPS counter
        fps.update()

    # if "ctrl+c" is pressed in the terminal, break from the loop
    except KeyboardInterrupt:
        break

    # if there's a problem reading a frame, break gracefully
    except AttributeError:
        break

# stop the FPS counter timer
fps.stop()

# destroy all windows if we are displaying them
if args["display"] > 0:
    cv2.destroyAllWindows()

# stop the video stream
vs.stop()

# clean up the graph and device
graph.DeallocateGraph()
device.CloseDevice()

# display FPS information
print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))