Guest User

Untitled

a guest
Dec 28th, 2023
144
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.87 KB | None | 0 0
  1. import tkinter as tk
  2. from tkinter import Canvas
  3. from tkinter import Label
  4. import cv2
  5. import depthai as dai
  6. from PIL import Image, ImageTk
  7. from threading import Thread
  8. import argparse
  9. from pathlib import Path
  10. import json
  11. import time
  12. import numpy as np
  13. import blobconverter
  14. import gc
  15. from queue import Queue
  16.  
  17.  
  18. # parse arguments
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument("-m", "--model", help="Provide model name or model path for inference",
  21. default='yolov4_tiny_coco_416x416', type=str)
  22. parser.add_argument("-c", "--config", help="Provide config path for inference",
  23. default='json/yolov4-tiny.json', type=str)
  24. parser.add_argument('-ff', '--full_frame', action="store_true", help="Perform tracking on full RGB frame", default=False)
  25. args = parser.parse_args()
  26.  
  27. fullFrameTracking = args.full_frame
  28.  
  29. # parse config
  30. configPath = Path(args.config)
  31. if not configPath.exists():
  32. raise ValueError("Path {} does not exist!".format(configPath))
  33.  
  34. with configPath.open() as f:
  35. config = json.load(f)
  36. nnConfig = config.get("nn_config", {})
  37.  
  38. # parse input shape
  39. if "input_size" in nnConfig:
  40. W, H = tuple(map(int, nnConfig.get("input_size").split('x')))
  41.  
  42. # extract metadata
  43. metadata = nnConfig.get("NN_specific_metadata", {})
  44. classes = metadata.get("classes", {})
  45. coordinates = metadata.get("coordinates", {})
  46. anchors = metadata.get("anchors", {})
  47. anchorMasks = metadata.get("anchor_masks", {})
  48. iouThreshold = metadata.get("iou_threshold", {})
  49. confidenceThreshold = metadata.get("confidence_threshold", {})
  50.  
  51. print(metadata)
  52.  
  53. # parse labels
  54. nnMappings = config.get("mappings", {})
  55. labels = nnMappings.get("labels", {})
  56.  
  57. # get model path
  58. nnPath = args.model
  59. if not Path(nnPath).exists():
  60. print("No blob found at {}. Looking into DepthAI model zoo.".format(nnPath))
  61. nnPath = str(blobconverter.from_zoo(args.model, shaves=6, zoo_type="depthai", use_cache=True))
  62. # sync outputs
  63. syncNN = True
  64.  
  65.  
  66. def createPipeline():
  67. # Create pipeline
  68. pipeline = dai.Pipeline()
  69.  
  70. # Define sources and outputs
  71. camRgb = pipeline.create(dai.node.ColorCamera)
  72. detectionNetwork = pipeline.create(dai.node.YoloDetectionNetwork)
  73. objectTracker = pipeline.create(dai.node.ObjectTracker)
  74.  
  75. xoutRgb = pipeline.create(dai.node.XLinkOut)
  76. # nnOut = pipeline.create(dai.node.XLinkOut)
  77. trackerOut = pipeline.create(dai.node.XLinkOut)
  78.  
  79. xoutRgb.setStreamName("rgb")
  80. # nnOut.setStreamName("nn")
  81. trackerOut.setStreamName("tracklets")
  82.  
  83. streams = ("rgb")
  84.  
  85. # Properties
  86. camRgb.setPreviewSize(320, 320)
  87. camRgb.setResolution(dai.ColorCameraProperties.SensorResolution.THE_1080_P)
  88. camRgb.setInterleaved(False)
  89. camRgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.BGR)
  90. camRgb.setFps(25)
  91.  
  92. # Network specific settings
  93. detectionNetwork.setConfidenceThreshold(confidenceThreshold)
  94. detectionNetwork.setNumClasses(classes)
  95. detectionNetwork.setCoordinateSize(coordinates)
  96. detectionNetwork.setAnchors(anchors)
  97. detectionNetwork.setAnchorMasks(anchorMasks)
  98. detectionNetwork.setIouThreshold(iouThreshold)
  99. detectionNetwork.setBlobPath(nnPath)
  100. detectionNetwork.setNumInferenceThreads(2)
  101. detectionNetwork.input.setBlocking(False)
  102.  
  103. # possible tracking types: ZERO_TERM_COLOR_HISTOGRAM, ZERO_TERM_IMAGELESS, SHORT_TERM_IMAGELESS, SHORT_TERM_KCF
  104. objectTracker.setTrackerType(dai.TrackerType.ZERO_TERM_COLOR_HISTOGRAM)
  105. # take the smallest ID when new object is tracked, possible options: SMALLEST_ID, UNIQUE_ID
  106. objectTracker.setTrackerIdAssignmentPolicy(dai.TrackerIdAssignmentPolicy.SMALLEST_ID)
  107. # Change this to track more objects
  108. objectTracker.setMaxObjectsToTrack(5)
  109. #Above this threshold the detected objects will be tracked. Default 0, all image detections are tracked.
  110. objectTracker.setTrackerThreshold(.87)
  111.  
  112. # Linking
  113. camRgb.preview.link(detectionNetwork.input)
  114. objectTracker.passthroughTrackerFrame.link(xoutRgb.input)
  115.  
  116. if fullFrameTracking:
  117. camRgb.video.link(objectTracker.inputTrackerFrame)
  118. else:
  119. detectionNetwork.passthrough.link(objectTracker.inputTrackerFrame)
  120.  
  121. detectionNetwork.passthrough.link(objectTracker.inputDetectionFrame)
  122. detectionNetwork.out.link(objectTracker.inputDetections)
  123. objectTracker.out.link(trackerOut.input)
  124.  
  125. return pipeline, streams
  126.  
  127.  
  128. def run(pipeline):
  129. # Connect to device and start pipeline
  130. with dai.Device(pipeline) as device:
  131. # Output queues will be used to get the rgb frames and nn data from the outputs defined above
  132. qRgb = device.getOutputQueue(name="rgb", maxSize=4, blocking=False)
  133. tracklets = device.getOutputQueue("tracklets", 4, False)
  134.  
  135. frame = None
  136. startTime = time.monotonic()
  137. counter = 0
  138.  
  139. # nn data, being the bounding box locations, are in <0..1> range - they need to be normalized with frame width/height
  140. def frameNorm(frame, bbox):
  141. normVals = np.full(len(bbox), frame.shape[0])
  142. normVals[::2] = frame.shape[1]
  143. return (np.clip(np.array(bbox), 0, 1) * normVals).astype(int)
  144.  
  145. def displayComponents(helmet_color: str, vest_color: str):
  146. canvas.create_rectangle(screen_width-520,400,screen_width-300,450, fill=helmet_color)
  147. canvas.create_rectangle(screen_width - 520, 500, screen_width - 300, 550, fill=vest_color)
  148.  
  149. def displayCompBoth(text: str,width: int):
  150. ppe_det.config(text=text)
  151. ppe_det.place(width=width)
  152.  
  153. def displayBoundingBoxes(frame, label, t, x1, y1, x2, y2):
  154. cv2.putText(frame, str(label), (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
  155. # cv2.putText(frame, f"ID: {[t.id]}", (x1 + 10, y1 + 35), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
  156. # cv2.putText(frame, t.status.name, (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
  157. cv2.rectangle(frame, (x1, y1), (x2, y2), color, cv2.FONT_HERSHEY_SIMPLEX)
  158.  
  159. def placeFrames(frame, image_label):
  160. frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  161. image = Image.fromarray(frame)
  162. img_width = screen_width // 2
  163. img_height = screen_height // 2
  164. image = image.resize((640, 640))
  165. image = ImageTk.PhotoImage(image)
  166. image_label.config(image=image)
  167. image_label.image = image
  168. image_label.place(x=300, y=300)
  169.  
  170. def displayFrame(name, frame, detections, image_label):
  171. helmet_color = 'red'
  172. vest_color = 'red'
  173. if len(detections) < 1:
  174. helmet_color = 'red'
  175. vest_color = 'red'
  176. displayComponents(helmet_color, vest_color)
  177. print("Not detecting")
  178. else:
  179. for t in trackletsData:
  180. roi = t.roi.denormalize(frame.shape[1], frame.shape[0])
  181. x1 = int(roi.topLeft().x)
  182. y1 = int(roi.topLeft().y)
  183. x2 = int(roi.bottomRight().x)
  184. y2 = int(roi.bottomRight().y)
  185.  
  186. try:
  187. label = labels[t.label]
  188. except:
  189. label = t.label
  190. if label == "Helmet":
  191. displayBoundingBoxes(frame, label, t, x1, y1, x2, y2)
  192. helmet_color = 'green'
  193. if label == "Vest":
  194. displayBoundingBoxes(frame, label, t, x1, y1, x2, y2)
  195. vest_color = 'green'
  196.  
  197. displayComponents(helmet_color,vest_color)
  198. if helmet_color == 'green' and vest_color == 'green':
  199. displayCompBoth('OK', 150)
  200. else:
  201. displayCompBoth('NOT COMPLETE', 150)
  202.  
  203.  
  204. # initialize and place frames
  205. placeFrames(frame, image_label)
  206.  
  207. # Update the Tkinter Window
  208. window.update()
  209. window.update_idletasks()
  210. gc.collect()
  211.  
  212. while True:
  213. inRgb = qRgb.get()
  214. track = tracklets.get()
  215.  
  216. counter += 1
  217. current_time = time.monotonic()
  218. if (current_time - startTime) > 1:
  219. fps = counter / (current_time - startTime)
  220. counter = 0
  221. startTime = current_time
  222. color = (255, 0, 0)
  223. frame = inRgb.getCvFrame()
  224. trackletsData = track.tracklets
  225.  
  226. displayFrame('rgb',frame,trackletsData, image_label)
  227.  
  228. def _from_rgb(rgb):
  229. """translates an rgb tuple of int to a tkinter friendly color code
  230. """
  231. return "#%02x%02x%02x" % rgb
  232.  
  233.  
  234.  
  235. if __name__ == '__main__':
  236. pipeline, streamNames = createPipeline()
  237.  
  238. window = tk.Tk()
  239. screen_width = window.winfo_screenwidth()
  240. screen_height = window.winfo_screenheight()
  241. window.geometry(f"{screen_width}x{screen_height}")
  242. window.title("PPE Detection")
  243.  
  244. # Create a canvas widget
  245. canvas = Canvas(window, width=500, height=300)
  246. canvas.pack(expand=tk.YES, fill=tk.BOTH)
  247.  
  248. # Draw Vertical line - divider on canvas
  249. print("CANVAS WINFO HEIGHT: ", canvas.winfo_height())
  250. canvas.create_line((screen_width // 2) + 200, 0, (screen_width // 2) + 200, screen_height,
  251. fill=_from_rgb((30, 113, 183)), width=3)
  252.  
  253.  
  254. # Create Image Label
  255. image_label = tk.Label(window)
  256. '''
  257. CLASSES
  258. '''
  259. ## Helmet
  260. canvas.create_rectangle(screen_width-520,400,screen_width-300,450)
  261. canvas.create_text(screen_width-655, 410, text="HELMET", font=("Arial", 20), anchor=tk.NW)
  262. # ## Vest
  263. canvas.create_rectangle(screen_width-520,500,screen_width-300,550)
  264. canvas.create_text(screen_width-655, 510, text="VEST", font=("Arial", 20), anchor=tk.NW)
  265. # ## Both
  266. canvas.create_text(screen_width-655, 610, text="PPE:", font=("Arial", 20), anchor=tk.NW)
  267. canvas.create_rectangle(screen_width-520,600,screen_width-300,650)
  268. ppe_det = tk.Label(window, text='-', font=("Arial, 15"))
  269. ppe_det.place(x=screen_width-490, y=610, width=150, height=30)
  270.  
  271. # Place logo
  272. logo = Image.open("./logo/logo.jpg")
  273. logo = logo.resize((180, 80))
  274.  
  275. test = ImageTk.PhotoImage(logo)
  276. logo_label = tk.Label(image=test)
  277. logo_label.image = test
  278. logo_label.place(x=0, y=0)
  279.  
  280.  
  281. def callback(*args):
  282. global currentStream
  283. currentStream = window.getvar(args[0])
  284. cv2.destroyAllWindows()
  285.  
  286.  
  287. currentStream = streamNames[0]
  288. print("CURRENT STREAM: ", currentStream)
  289.  
  290. currentStreamVar = tk.StringVar(window)
  291. currentStreamVar.set(currentStream) # default value
  292. currentStreamVar.trace_add("write", callback)
  293.  
  294.  
  295. thread = Thread(target=run, args=(pipeline,))
  296. thread.daemon = True
  297. thread.start()
  298.  
  299. window.mainloop()
  300.  
Advertisement
Add Comment
Please, Sign In to add comment