Data hosted with ♥ by Pastebin.com - Download Raw - See Original
  1. #!/usr/bin/env python3
  2.  
  3. from pathlib import Path
  4. import cv2
  5. import depthai as dai
  6. import numpy as np
  7. import time
  8. import argparse
  9.  
# Class labels for the MobileNet-SSD (PASCAL VOC) model; index matches the
# network's class-id output, with index 0 reserved for "background".
labelMap = ["background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow",
            "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

# Default path to the compiled MobileNet-SSD blob, resolved relative to this
# script so it works regardless of the current working directory.
nnPathDefault = str((Path(__file__).parent / Path('models/mobilenet-ssd_openvino_2021.4_5shave.blob')).resolve().absolute())
  14.  
# Create pipeline
pipeline = dai.Pipeline()

# Define sources and outputs.
# NOTE(review): despite its name, `camRgb` is an ImageManip node (frames are
# pushed in from the host via XLinkIn, not captured by an on-device camera).
camRgb = pipeline.create(dai.node.ImageManip)
# NOTE(review): despite its name, this is a plain MobileNetDetectionNetwork,
# not a spatial (depth-aware) one — detections carry no XYZ coordinates.
spatialDetectionNetwork = pipeline.create(dai.node.MobileNetDetectionNetwork)
objectTracker = pipeline.create(dai.node.ObjectTracker)

xoutFrame = pipeline.create(dai.node.XLinkOut)   # tracker passthrough frames -> host
xinFrame = pipeline.create(dai.node.XLinkIn)     # host video frames -> device
trackerOut = pipeline.create(dai.node.XLinkOut)  # tracklet results -> host

xinFrame.setStreamName("inFrame")
xoutFrame.setStreamName("preview")
trackerOut.setStreamName("tracklets")

# Properties
# Resize incoming frames to the 300x300 input expected by MobileNet-SSD.
camRgb.initialConfig.setResize(300,300)
camRgb.initialConfig.setFrameType(dai.RawImgFrame.Type.BGR888p)
camRgb.setKeepAspectRatio(True)

# NOTE(review): this overrides the setResize(300,300) above with a
# letterboxed thumbnail resize — only one of the two takes effect; confirm
# which behavior is intended and drop the other call.
camRgb.initialConfig.setResizeThumbnail(300,300)

# setting node configs

spatialDetectionNetwork.setBlobPath(nnPathDefault)
spatialDetectionNetwork.setConfidenceThreshold(0.5)
# Non-blocking input: drop frames rather than stall when the NN falls behind.
spatialDetectionNetwork.input.setBlocking(False)

objectTracker.setTrackerType(dai.TrackerType.ZERO_TERM_COLOR_HISTOGRAM)
# Reuse the smallest free ID when assigning new tracklets.
objectTracker.setTrackerIdAssignmentPolicy(dai.TrackerIdAssignmentPolicy.SMALLEST_ID)

# Linking

xinFrame.out.link(camRgb.inputImage)
camRgb.out.link(spatialDetectionNetwork.input)


objectTracker.passthroughTrackerFrame.link(xoutFrame.input) #this function is used to show the tracking frame
objectTracker.out.link(trackerOut.input)
#link rgb camera's output to xoutRgb

spatialDetectionNetwork.passthrough.link(objectTracker.inputTrackerFrame)

spatialDetectionNetwork.passthrough.link(objectTracker.inputDetectionFrame)
spatialDetectionNetwork.out.link(objectTracker.inputDetections)
  60.  
  61. def to_planar(arr: np.ndarray, shape: tuple) -> np.ndarray:
  62.         return cv2.resize(arr, shape).transpose(2, 0, 1).flatten()
  63. # Connect to device and start pipeline
  64. with dai.Device(pipeline) as device:
  65.  
  66.     cap = cv2.VideoCapture("walking.mp4")
  67.     qIn = device.getInputQueue(name="inFrame")
  68.  
  69.     preview = device.getOutputQueue("preview", 4, False)
  70.     tracklets = device.getOutputQueue("tracklets", 4, False)
  71.     startTime = time.monotonic()
  72.     counter = 0
  73.     fps = 0
  74.     color = (255, 255, 255)
  75.  
  76.     from threading import Thread
  77.  
  78.     def send_frames(queue, cap):
  79.         while True:
  80.             ret, rgb = cap.read()
  81.             if not ret:
  82.                 print("Can't receive frame (stream end?). Exiting ...")
  83.                 break
  84.  
  85.             rgbImg = dai.ImgFrame()
  86.             rgbImg.setData(to_planar(rgb, (300, 300)))
  87.             rgbImg.setType(dai.ImgFrame.Type.BGR888p)
  88.             rgbImg.setTimestamp(time.monotonic())
  89.             rgbImg.setWidth(300)
  90.             rgbImg.setHeight(300)
  91.             qIn.send(rgbImg)
  92.  
  93.     send_thread = Thread(target=send_frames, args=(qIn, cap,))
  94.     send_thread.start()
  95.  
  96.     while send_thread.is_alive():
  97.  
  98.         imgFrame = preview.get()
  99.         print("RGB Image Sent")
  100.         print("imgFrame received")
  101.         track = tracklets.get()
  102.         print("tracklets received")
  103.  
  104.         frame = imgFrame.getCvFrame()
  105.         trackletsData = track.tracklets
  106.         print("trackletsData", trackletsData)
  107.         cv2.imshow("tracker", frame)
  108.  
  109.         if cv2.waitKey(1) == ord('q'):
  110.             break