import tkinter as tk
from tkinter import Canvas
from tkinter import Label
import cv2
import depthai as dai
from PIL import Image, ImageTk
from threading import Thread
import argparse
from pathlib import Path
import json
import time
import numpy as np
import blobconverter
import gc
from queue import Queue
# parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", help="Provide model name or model path for inference",
                    default='yolov4_tiny_coco_416x416', type=str)
parser.add_argument("-c", "--config", help="Provide config path for inference",
                    default='json/yolov4-tiny.json', type=str)
parser.add_argument('-ff', '--full_frame', action="store_true", help="Perform tracking on full RGB frame")
args = parser.parse_args()

fullFrameTracking = args.full_frame
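# Example invocation (using the defaults above; the script filename is illustrative):
#   python3 main.py -m yolov4_tiny_coco_416x416 -c json/yolov4-tiny.json
# Add -ff to track on the full 1080p RGB frame instead of the NN passthrough frame.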
# parse config
configPath = Path(args.config)
if not configPath.exists():
    raise ValueError("Path {} does not exist!".format(configPath))

with configPath.open() as f:
    config = json.load(f)
nnConfig = config.get("nn_config", {})
# parse input shape
W, H = 320, 320  # fallback if the config does not provide an input_size
if "input_size" in nnConfig:
    W, H = tuple(map(int, nnConfig.get("input_size").split('x')))

# extract metadata
metadata = nnConfig.get("NN_specific_metadata", {})
classes = metadata.get("classes", {})
coordinates = metadata.get("coordinates", {})
anchors = metadata.get("anchors", {})
anchorMasks = metadata.get("anchor_masks", {})
iouThreshold = metadata.get("iou_threshold", {})
confidenceThreshold = metadata.get("confidence_threshold", {})
print(metadata)

# parse labels
nnMappings = config.get("mappings", {})
labels = nnMappings.get("labels", {})
# get model path
nnPath = args.model
if not Path(nnPath).exists():
    print("No blob found at {}. Looking into DepthAI model zoo.".format(nnPath))
    nnPath = str(blobconverter.from_zoo(args.model, shaves=6, zoo_type="depthai", use_cache=True))

# sync outputs
syncNN = True
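# Note: syncNN is currently unused in this script; it appears to be kept for parity
# with the stock DepthAI detection examples this code is based on.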
def createPipeline():
    # Create pipeline
    pipeline = dai.Pipeline()

    # Define sources and outputs
    camRgb = pipeline.create(dai.node.ColorCamera)
    detectionNetwork = pipeline.create(dai.node.YoloDetectionNetwork)
    objectTracker = pipeline.create(dai.node.ObjectTracker)
    xoutRgb = pipeline.create(dai.node.XLinkOut)
    # nnOut = pipeline.create(dai.node.XLinkOut)
    trackerOut = pipeline.create(dai.node.XLinkOut)

    xoutRgb.setStreamName("rgb")
    # nnOut.setStreamName("nn")
    trackerOut.setStreamName("tracklets")
    streams = ("rgb",)  # one-element tuple; ("rgb") without the comma is just a string

    # Properties
    # Preview size must match the NN input size parsed from the config;
    # the original hardcoded 320x320 here, which only works for a 320x320 blob.
    camRgb.setPreviewSize(W, H)
    camRgb.setResolution(dai.ColorCameraProperties.SensorResolution.THE_1080_P)
    camRgb.setInterleaved(False)
    camRgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.BGR)
    camRgb.setFps(25)

    # Network specific settings
    detectionNetwork.setConfidenceThreshold(confidenceThreshold)
    detectionNetwork.setNumClasses(classes)
    detectionNetwork.setCoordinateSize(coordinates)
    detectionNetwork.setAnchors(anchors)
    detectionNetwork.setAnchorMasks(anchorMasks)
    detectionNetwork.setIouThreshold(iouThreshold)
    detectionNetwork.setBlobPath(nnPath)
    detectionNetwork.setNumInferenceThreads(2)
    detectionNetwork.input.setBlocking(False)

    # possible tracking types: ZERO_TERM_COLOR_HISTOGRAM, ZERO_TERM_IMAGELESS, SHORT_TERM_IMAGELESS, SHORT_TERM_KCF
    objectTracker.setTrackerType(dai.TrackerType.ZERO_TERM_COLOR_HISTOGRAM)
    # take the smallest available ID when a new object is tracked; possible options: SMALLEST_ID, UNIQUE_ID
    objectTracker.setTrackerIdAssignmentPolicy(dai.TrackerIdAssignmentPolicy.SMALLEST_ID)
    # Change this to track more objects
    objectTracker.setMaxObjectsToTrack(5)
    # Detections above this threshold are tracked. Default 0: all detections are tracked.
    objectTracker.setTrackerThreshold(0.87)

    # Linking
    camRgb.preview.link(detectionNetwork.input)
    objectTracker.passthroughTrackerFrame.link(xoutRgb.input)
    if fullFrameTracking:
        camRgb.video.link(objectTracker.inputTrackerFrame)
    else:
        detectionNetwork.passthrough.link(objectTracker.inputTrackerFrame)
    detectionNetwork.passthrough.link(objectTracker.inputDetectionFrame)
    detectionNetwork.out.link(objectTracker.inputDetections)
    objectTracker.out.link(trackerOut.input)

    return pipeline, streams
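# Pipeline topology (sketch): ColorCamera.preview -> YoloDetectionNetwork -> ObjectTracker,
# with the tracker's passthrough frame ("rgb") and its tracklets ("tracklets") streamed to
# the host over XLink. With -ff, the tracker consumes the full 1080p video stream instead
# of the NN passthrough frame.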
def run(pipeline):
    # Connect to device and start pipeline
    with dai.Device(pipeline) as device:
        # Output queues will be used to get the rgb frames and nn data from the outputs defined above
        qRgb = device.getOutputQueue(name="rgb", maxSize=4, blocking=False)
        tracklets = device.getOutputQueue("tracklets", 4, False)

        frame = None
        startTime = time.monotonic()
        counter = 0

        # nn bounding box coordinates are in the <0..1> range - they need to be scaled by frame width/height
        def frameNorm(frame, bbox):
            normVals = np.full(len(bbox), frame.shape[0])
            normVals[::2] = frame.shape[1]
            return (np.clip(np.array(bbox), 0, 1) * normVals).astype(int)
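        # frameNorm is currently unused here (tracklet ROIs are denormalized via
        # roi.denormalize below); it appears to be left over from the detection-only
        # version of this example.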
        def displayComponents(helmet_color: str, vest_color: str):
            canvas.create_rectangle(screen_width - 520, 400, screen_width - 300, 450, fill=helmet_color)
            canvas.create_rectangle(screen_width - 520, 500, screen_width - 300, 550, fill=vest_color)

        def displayCompBoth(text: str, width: int):
            ppe_det.config(text=text)
            ppe_det.place(width=width)

        def displayBoundingBoxes(frame, label, t, x1, y1, x2, y2):
            cv2.putText(frame, str(label), (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            # cv2.putText(frame, f"ID: {[t.id]}", (x1 + 10, y1 + 35), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            # cv2.putText(frame, t.status.name, (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)  # thickness 2; the original passed a font constant here

        def placeFrames(frame, image_label):
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame)
            img_width = screen_width // 2   # currently unused; the display size is fixed below
            img_height = screen_height // 2
            image = image.resize((640, 640))
            image = ImageTk.PhotoImage(image)
            image_label.config(image=image)
            image_label.image = image  # keep a reference so Tkinter does not garbage-collect the image
            image_label.place(x=300, y=300)
        def displayFrame(name, frame, detections, image_label):
            helmet_color = 'red'
            vest_color = 'red'
            if len(detections) < 1:
                displayComponents(helmet_color, vest_color)
                print("Not detecting")
            else:
                for t in detections:
                    roi = t.roi.denormalize(frame.shape[1], frame.shape[0])
                    x1 = int(roi.topLeft().x)
                    y1 = int(roi.topLeft().y)
                    x2 = int(roi.bottomRight().x)
                    y2 = int(roi.bottomRight().y)

                    try:
                        label = labels[t.label]
                    except (KeyError, IndexError, TypeError):
                        label = t.label

                    if label == "Helmet":
                        displayBoundingBoxes(frame, label, t, x1, y1, x2, y2)
                        helmet_color = 'green'
                    if label == "Vest":
                        displayBoundingBoxes(frame, label, t, x1, y1, x2, y2)
                        vest_color = 'green'
                displayComponents(helmet_color, vest_color)
                if helmet_color == 'green' and vest_color == 'green':
                    displayCompBoth('OK', 150)
                else:
                    displayCompBoth('NOT COMPLETE', 150)

            # initialize and place frames
            placeFrames(frame, image_label)
            # Update the Tkinter window
            window.update()
            window.update_idletasks()
            gc.collect()
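        # Note: run() executes on a worker thread while the Tk widgets were created on the
        # main thread. Tkinter is not thread-safe, so driving window.update() from here is
        # fragile; a more robust design would post frames to a Queue (already imported) and
        # have the main thread poll it with window.after(). Kept as-is to preserve the
        # original structure.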
        while True:
            inRgb = qRgb.get()
            track = tracklets.get()

            counter += 1
            current_time = time.monotonic()
            if (current_time - startTime) > 1:
                fps = counter / (current_time - startTime)  # computed but not currently displayed
                counter = 0
                startTime = current_time

            color = (255, 0, 0)
            frame = inRgb.getCvFrame()
            trackletsData = track.tracklets
            displayFrame('rgb', frame, trackletsData, image_label)
def _from_rgb(rgb):
    """Translate an rgb tuple of ints to a Tkinter-friendly color code."""
    return "#%02x%02x%02x" % rgb
if __name__ == '__main__':
    pipeline, streamNames = createPipeline()

    window = tk.Tk()
    screen_width = window.winfo_screenwidth()
    screen_height = window.winfo_screenheight()
    window.geometry(f"{screen_width}x{screen_height}")
    window.title("PPE Detection")

    # Create a canvas widget
    canvas = Canvas(window, width=500, height=300)
    canvas.pack(expand=tk.YES, fill=tk.BOTH)

    # Draw vertical divider line on canvas
    print("CANVAS WINFO HEIGHT: ", canvas.winfo_height())
    canvas.create_line((screen_width // 2) + 200, 0, (screen_width // 2) + 200, screen_height,
                       fill=_from_rgb((30, 113, 183)), width=3)

    # Create image label
    image_label = tk.Label(window)

    '''
    CLASSES
    '''
    ## Helmet
    canvas.create_rectangle(screen_width - 520, 400, screen_width - 300, 450)
    canvas.create_text(screen_width - 655, 410, text="HELMET", font=("Arial", 20), anchor=tk.NW)
    ## Vest
    canvas.create_rectangle(screen_width - 520, 500, screen_width - 300, 550)
    canvas.create_text(screen_width - 655, 510, text="VEST", font=("Arial", 20), anchor=tk.NW)
    ## Both
    canvas.create_text(screen_width - 655, 610, text="PPE:", font=("Arial", 20), anchor=tk.NW)
    canvas.create_rectangle(screen_width - 520, 600, screen_width - 300, 650)
    ppe_det = tk.Label(window, text='-', font=("Arial", 15))  # font was ("Arial, 15"), a single malformed string
    ppe_det.place(x=screen_width - 490, y=610, width=150, height=30)

    # Place logo
    logo = Image.open("./logo/logo.jpg")
    logo = logo.resize((180, 80))
    test = ImageTk.PhotoImage(logo)
    logo_label = tk.Label(image=test)
    logo_label.image = test
    logo_label.place(x=0, y=0)
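    # The stream-selection plumbing below appears to be left over from a multi-stream
    # variant of this demo; with the single "rgb" stream it only sets a default value.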
    def callback(*args):
        global currentStream
        currentStream = window.getvar(args[0])
        cv2.destroyAllWindows()

    currentStream = streamNames[0]
    print("CURRENT STREAM: ", currentStream)
    currentStreamVar = tk.StringVar(window)
    currentStreamVar.set(currentStream)  # default value
    currentStreamVar.trace_add("write", callback)

    thread = Thread(target=run, args=(pipeline,))
    thread.daemon = True
    thread.start()

    window.mainloop()