Guest User

nudenet detector

a guest
Apr 10th, 2021
1,762
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.27 KB | None | 0 0
  1. import os
  2. import cv2
  3. import pydload
  4. import logging
  5. import numpy as np
  6. import onnxruntime
  7. from progressbar import progressbar
  8.  
  9. from .detector_utils import preprocess_image
  10. from .video_utils import get_interest_frames_from_video
  11.  
  12.  
  13. def dummy(x):
  14.     return x
  15.  
  16.  
  17. FILE_URLS = {
  18.     "default": {
  19.         "checkpoint": "https://github.com/notAI-tech/NudeNet/releases/download/v0/detector_v2_default_checkpoint.onnx",
  20.         "classes": "https://github.com/notAI-tech/NudeNet/releases/download/v0/detector_v2_default_classes",
  21.     },
  22.     "base": {
  23.         "checkpoint": "https://github.com/notAI-tech/NudeNet/releases/download/v0/detector_v2_base_checkpoint.onnx",
  24.         "classes": "https://github.com/notAI-tech/NudeNet/releases/download/v0/detector_v2_base_classes",
  25.     },
  26. }
  27.  
  28.  
  29. class Detector:
  30.     detection_model = None
  31.     classes = None
  32.  
  33.     def __init__(self, model_name="default"):
  34.         """
  35.        model = Detector()
  36.        """
  37.         checkpoint_url = FILE_URLS[model_name]["checkpoint"]
  38.         classes_url = FILE_URLS[model_name]["classes"]
  39.  
  40.         home = os.path.expanduser("~")
  41.         model_folder = os.path.join(home, f".NudeNet/")
  42.         if not os.path.exists(model_folder):
  43.             os.makedirs(model_folder)
  44.  
  45.         checkpoint_name = os.path.basename(checkpoint_url)
  46.         checkpoint_path = os.path.join(model_folder, checkpoint_name)
  47.         classes_path = os.path.join(model_folder, "classes")
  48.  
  49.         if not os.path.exists(checkpoint_path):
  50.             print("Downloading the checkpoint to", checkpoint_path)
  51.             pydload.dload(checkpoint_url, save_to_path=checkpoint_path, max_time=None)
  52.  
  53.         if not os.path.exists(classes_path):
  54.             print("Downloading the classes list to", classes_path)
  55.             pydload.dload(classes_url, save_to_path=classes_path, max_time=None)
  56.  
  57.         self.detection_model = onnxruntime.InferenceSession(checkpoint_path)
  58.  
  59.         self.classes = [c.strip() for c in open(classes_path).readlines() if c.strip()]
  60.  
  61.     def detect_video(
  62.         self, video_path, mode="default", min_prob=0.6, batch_size=2, show_progress=True
  63.     ):
  64.         frame_indices, frames, fps, video_length = get_interest_frames_from_video(
  65.             video_path
  66.         )
  67.         logging.debug(
  68.             f"VIDEO_PATH: {video_path}, FPS: {fps}, Important frame indices: {frame_indices}, Video length: {video_length}"
  69.         )
  70.         if mode == "fast":
  71.             frames = [
  72.                 preprocess_image(frame, min_side=480, max_side=800) for frame in frames
  73.             ]
  74.         else:
  75.             frames = [preprocess_image(frame) for frame in frames]
  76.  
  77.         scale = frames[0][1]
  78.         frames = [frame[0] for frame in frames]
  79.         all_results = {
  80.             "metadata": {
  81.                 "fps": fps,
  82.                 "video_length": video_length,
  83.                 "video_path": video_path,
  84.             },
  85.             "preds": {},
  86.         }
  87.  
  88.         progress_func = progressbar
  89.  
  90.         if not show_progress:
  91.             progress_func = dummy
  92.  
  93.         for _ in progress_func(range(int(len(frames) / batch_size) + 1)):
  94.             batch = frames[:batch_size]
  95.             batch_indices = frame_indices[:batch_size]
  96.             frames = frames[batch_size:]
  97.             frame_indices = frame_indices[batch_size:]
  98.             if batch_indices:
  99.                 outputs = self.detection_model.run(
  100.                     [s_i.name for s_i in self.detection_model.get_outputs()],
  101.                     {self.detection_model.get_inputs()[0].name: np.asarray(batch)},
  102.                 )
  103.  
  104.                 labels = [op for op in outputs if op.dtype == "int32"][0]
  105.                 scores = [op for op in outputs if isinstance(op[0][0], np.float32)][0]
  106.                 boxes = [op for op in outputs if isinstance(op[0][0], np.ndarray)][0]
  107.  
  108.                 boxes /= scale
  109.                 for frame_index, frame_boxes, frame_scores, frame_labels in zip(
  110.                     frame_indices, boxes, scores, labels
  111.                 ):
  112.                     if frame_index not in all_results["preds"]:
  113.                         all_results["preds"][frame_index] = []
  114.  
  115.                     for box, score, label in zip(
  116.                         frame_boxes, frame_scores, frame_labels
  117.                     ):
  118.                         if score < min_prob:
  119.                             continue
  120.                         box = box.astype(int).tolist()
  121.                         label = self.classes[label]
  122.  
  123.                         all_results["preds"][frame_index].append(
  124.                             {
  125.                                 "box": [int(c) for c in box],
  126.                                 "score": float(score),
  127.                                 "label": label,
  128.                             }
  129.                         )
  130.  
  131.         return all_results
  132.  
  133.     def detect(self, img_path, mode="default", min_prob=None):
  134.         if mode == "fast":
  135.             image, scale = preprocess_image(img_path, min_side=480, max_side=800)
  136.             if not min_prob:
  137.                 min_prob = 0.5
  138.         else:
  139.             image, scale = preprocess_image(img_path)
  140.             if not min_prob:
  141.                 min_prob = 0.6
  142.  
  143.         outputs = self.detection_model.run(
  144.             [s_i.name for s_i in self.detection_model.get_outputs()],
  145.             {self.detection_model.get_inputs()[0].name: np.expand_dims(image, axis=0)},
  146.         )
  147.  
  148.         labels = [op for op in outputs if op.dtype == "int32"][0]
  149.         scores = [op for op in outputs if isinstance(op[0][0], np.float32)][0]
  150.         boxes = [op for op in outputs if isinstance(op[0][0], np.ndarray)][0]
  151.  
  152.         boxes /= scale
  153.         processed_boxes = []
  154.         for box, score, label in zip(boxes[0], scores[0], labels[0]):
  155.             if score < min_prob:
  156.                 continue
  157.             box = box.astype(int).tolist()
  158.             label = self.classes[label]
  159.             processed_boxes.append(
  160.                 {"box": [int(c) for c in box], "score": float(score), "label": label}
  161.             )
  162.  
  163.         return processed_boxes
  164.  
  165.     def censor(self, img_path, out_path=None, visualize=False, parts_to_blur=[],typec=[],pixs=[8,8]):
  166.         if not out_path and not visualize:
  167.             print(
  168.                 "No out_path passed and visualize is set to false. There is no point in running this function then."
  169.             )
  170.             return
  171.  
  172.         image = cv2.imread(img_path)
  173.         boxes = self.detect(img_path)
  174.  
  175.         if parts_to_blur:
  176.             boxes = [i["box"] for i in boxes if i["label"] in parts_to_blur]
  177.         else:
  178.             boxes = [i["box"] for i in boxes]
  179.        
  180.        
  181.         if typec=='blur':
  182.             for box in boxes:                
  183.                 topLeft = (box[0], box[1])
  184.  
  185.                 bottomRight = (box[2]  , box[3])
  186.                 x, y = topLeft[0], topLeft[1]
  187.                 w, h = bottomRight[0] - topLeft[0], bottomRight[1] - topLeft[1]
  188.  
  189.                
  190.                 ROI = image[y:y+h, x:x+w]
  191.                 blur = cv2.GaussianBlur(ROI, (101,101), 0)
  192.  
  193.  
  194.                
  195.                 image[y:y+h, x:x+w] = blur        
  196.  
  197.            
  198.         elif typec=='pix':
  199.             for box in boxes:
  200.                 topLeft = (box[0], box[1])
  201.  
  202.                 bottomRight = (box[2]  , box[3])
  203.                 x, y = topLeft[0], topLeft[1]
  204.                 w, h = bottomRight[0] - topLeft[0], bottomRight[1] - topLeft[1]
  205.  
  206.                
  207.                 ROI = image[y:y+h, x:x+w]
  208.                 w_p, h_p = pixs[0], pixs[1]
  209.  
  210.                
  211.                 temp = cv2.resize(ROI,(w_p,h_p) , interpolation=cv2.INTER_LINEAR)
  212.                 height, width = ROI.shape[:2]
  213.                 output = cv2.resize(temp, (width, height), interpolation=cv2.INTER_NEAREST)
  214.                 image[y:y+h, x:x+w] = output
  215.  
  216.         else:        
  217.             for box in boxes:
  218.                 part = image[box[1] : box[3], box[0] : box[2]]
  219.                 image = cv2.rectangle(
  220.                 image, (box[0], box[1]), (box[2], box[3]), (0, 0, 0), cv2.FILLED)
  221.  
  222.         if visualize:
  223.             cv2.imshow("Blurred image", image)
  224.             cv2.waitKey(0)
  225.  
  226.         if out_path:
  227.             cv2.imwrite(out_path, image)
  228.  
  229.  
  230. if __name__ == "__main__":
  231.     m = Detector()
  232.     print(m.detect("/Users/bedapudi/Desktop/n2.jpg"))
  233.  
Advertisement
Add Comment
Please, Sign In to add comment