import json
import cv2
import depthai as dai
import numpy as np
import time


from spr.sprmai.helper_classes import Paths


nn_path = Paths.yolov5_oaknn
config_path = Paths.yolov5_oakcfg

with config_path.open() as f:
    config = json.load(f)
nn_config = config.get("nn_config", {})

# extract metadata from config
metadata = nn_config.get("NN_specific_metadata", {})
classes = metadata.get("classes", {})
coordinates = metadata.get("coordinates", {})
anchors = metadata.get("anchors", {})
anchor_masks = metadata.get("anchor_masks", {})
iou_threshold = metadata.get("iou_threshold", {})
confidence_threshold = metadata.get("confidence_threshold", {})


# Create pipeline
pipeline = dai.Pipeline()

# Define sources and outputs
camRgb = pipeline.create(dai.node.ColorCamera)
detectionNetwork = pipeline.create(dai.node.YoloSpatialDetectionNetwork)
objectTracker = pipeline.create(dai.node.ObjectTracker)

xlinkOut = pipeline.create(dai.node.XLinkOut)
trackerOut = pipeline.create(dai.node.XLinkOut)

xlinkOut.setStreamName("preview")
trackerOut.setStreamName("tracklets")

# Properties
camRgb.setPreviewSize(416, 416)
camRgb.setResolution(dai.ColorCameraProperties.SensorResolution.THE_1080_P)
camRgb.setInterleaved(False)
camRgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.BGR)
camRgb.setFps(40)


# Spatial detection network settings (values come from the JSON config loaded above)
detectionNetwork.setBlobPath(nn_path)
detectionNetwork.setConfidenceThreshold(confidence_threshold)
detectionNetwork.input.setBlocking(False)
detectionNetwork.setBoundingBoxScaleFactor(0.5)
detectionNetwork.setDepthLowerThreshold(100)
detectionNetwork.setDepthUpperThreshold(5000)

# Yolo specific parameters
detectionNetwork.setNumClasses(classes)
detectionNetwork.setCoordinateSize(coordinates)
detectionNetwork.setAnchors(anchors)
detectionNetwork.setAnchorMasks(anchor_masks)
detectionNetwork.setIouThreshold(iou_threshold)

# Track only detections with label 1; which index means "person" depends on the
# model's label map (COCO-trained YOLO models usually have person at index 0)
objectTracker.setDetectionLabelsToTrack([1])
# possible tracking types: ZERO_TERM_COLOR_HISTOGRAM, ZERO_TERM_IMAGELESS, SHORT_TERM_IMAGELESS, SHORT_TERM_KCF
objectTracker.setTrackerType(dai.TrackerType.ZERO_TERM_COLOR_HISTOGRAM)
# take the smallest ID when new object is tracked, possible options: SMALLEST_ID, UNIQUE_ID
objectTracker.setTrackerIdAssignmentPolicy(dai.TrackerIdAssignmentPolicy.SMALLEST_ID)

# Linking
camRgb.preview.link(detectionNetwork.input)
objectTracker.passthroughTrackerFrame.link(xlinkOut.input)

detectionNetwork.passthrough.link(objectTracker.inputTrackerFrame)
detectionNetwork.passthrough.link(objectTracker.inputDetectionFrame)
detectionNetwork.out.link(objectTracker.inputDetections)
objectTracker.out.link(trackerOut.input)

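# YoloSpatialDetectionNetwork expects depth frames on its inputDepth; without them the
# pipeline stalls. A minimal stereo depth setup (mono resolution and preset are assumptions,
# adjust to your board/needs):
monoLeft = pipeline.create(dai.node.MonoCamera)
monoRight = pipeline.create(dai.node.MonoCamera)
stereo = pipeline.create(dai.node.StereoDepth)

monoLeft.setResolution(dai.MonoCameraProperties.SensorResolution.THE_400_P)
monoLeft.setBoardSocket(dai.CameraBoardSocket.LEFT)
monoRight.setResolution(dai.MonoCameraProperties.SensorResolution.THE_400_P)
monoRight.setBoardSocket(dai.CameraBoardSocket.RIGHT)

stereo.setDefaultProfilePreset(dai.node.StereoDepth.PresetMode.HIGH_DENSITY)
stereo.setDepthAlign(dai.CameraBoardSocket.RGB)
stereo.setOutputSize(monoLeft.getResolutionWidth(), monoLeft.getResolutionHeight())

monoLeft.out.link(stereo.left)
monoRight.out.link(stereo.right)
stereo.depth.link(detectionNetwork.inputDepth)
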
# Connect to device and start pipeline
with dai.Device(pipeline) as device:

    preview = device.getOutputQueue(name="preview", maxSize=4, blocking=False)
    print(preview)
    tracklets = device.getOutputQueue(name="tracklets", maxSize=4, blocking=False)
    print(tracklets)
    startTime = time.monotonic()
    counter = 0
    fps = 0
    frame = None

    while True:

        imgFrame = preview.get()
        print(f'Getting imgFrame {imgFrame}')
        track = tracklets.get()
        print(f'Getting track {track}')
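
        # Update the rolling FPS estimate once per second; the counter/fps/startTime
        # variables above were initialised but never updated in the original script
        counter += 1
        current_time = time.monotonic()
        if (current_time - startTime) > 1:
            fps = counter / (current_time - startTime)
            counter = 0
            startTime = current_time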

        color = (255, 0, 0)
        frame = imgFrame.getCvFrame()
        trackletsData = track.tracklets
        print(f'Getting tracklets data {trackletsData}')
        for t in trackletsData:
            # Map the normalized ROI back to pixel coordinates on the preview frame
            roi = t.roi.denormalize(frame.shape[1], frame.shape[0])
            x1 = int(roi.topLeft().x)
            y1 = int(roi.topLeft().y)
            x2 = int(roi.bottomRight().x)
            y2 = int(roi.bottomRight().y)

            label = t.label

            cv2.putText(frame, str(label), (x1 + 10, y1 + 20),
                        cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.putText(frame, f"ID: {t.id}", (x1 + 10, y1 + 35),
                        cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.putText(frame, t.status.name, (x1 + 10, y1 + 50),
                        cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

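        # Overlay the measured NN FPS in the bottom-left corner of the preview
        cv2.putText(frame, "NN fps: {:.2f}".format(fps), (2, frame.shape[0] - 4),
                    cv2.FONT_HERSHEY_TRIPLEX, 0.4, color)
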
        cv2.imshow("tracker", frame)

        if cv2.waitKey(1) == ord('q'):
            break