Advertisement
Guest User

Face Fusion no NSFW content_analyser.py

a guest
Apr 10th, 2024
7,013
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.40 KB | None | 0 0
  1. from typing import Any
  2. from functools import lru_cache
  3. from time import sleep
  4. import threading
  5. import cv2
  6. import numpy
  7. import onnxruntime
  8. from tqdm import tqdm
  9.  
  10. import facefusion.globals
  11. from facefusion import process_manager, wording
  12. from facefusion.typing import VisionFrame, ModelSet, Fps
  13. from facefusion.execution import apply_execution_provider_options
  14. from facefusion.vision import get_video_frame, count_video_frame_total, read_image, detect_video_fps
  15. from facefusion.filesystem import resolve_relative_path, is_file
  16. from facefusion.download import conditional_download
  17.  
  18. CONTENT_ANALYSER = None
  19. THREAD_LOCK : threading.Lock = threading.Lock()
  20. MODELS : ModelSet =\
  21. {
  22. 'open_nsfw':
  23. {
  24. 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/open_nsfw.onnx',
  25. 'path': resolve_relative_path('../.assets/models/open_nsfw.onnx')
  26. }
  27. }
  28. PROBABILITY_LIMIT = 0.80
  29. RATE_LIMIT = 10
  30. STREAM_COUNTER = 0
  31.  
  32.  
  33. def get_content_analyser() -> Any:
  34. global CONTENT_ANALYSER
  35.  
  36. with THREAD_LOCK:
  37. while process_manager.is_checking():
  38. sleep(0.5)
  39. if CONTENT_ANALYSER is None:
  40. model_path = MODELS.get('open_nsfw').get('path')
  41. CONTENT_ANALYSER = onnxruntime.InferenceSession(model_path, providers = apply_execution_provider_options(facefusion.globals.execution_providers))
  42. return CONTENT_ANALYSER
  43.  
  44.  
  45. def clear_content_analyser() -> None:
  46. global CONTENT_ANALYSER
  47.  
  48. CONTENT_ANALYSER = None
  49.  
  50.  
  51. def pre_check() -> bool:
  52. download_directory_path = resolve_relative_path('../.assets/models')
  53. model_url = MODELS.get('open_nsfw').get('url')
  54. model_path = MODELS.get('open_nsfw').get('path')
  55.  
  56. if not facefusion.globals.skip_download:
  57. process_manager.check()
  58. conditional_download(download_directory_path, [ model_url ])
  59. process_manager.end()
  60. return is_file(model_path)
  61.  
  62.  
  63. def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool:
  64. global STREAM_COUNTER
  65.  
  66. STREAM_COUNTER = STREAM_COUNTER + 1
  67. if STREAM_COUNTER % int(video_fps) == 0:
  68. return analyse_frame(vision_frame)
  69. return False
  70.  
  71.  
  72. def analyse_frame(vision_frame : VisionFrame) -> bool:
  73. # Always return False to indicate that the content is safe
  74. return False
  75.  
  76.  
  77. def prepare_frame(vision_frame : VisionFrame) -> VisionFrame:
  78. vision_frame = cv2.resize(vision_frame, (224, 224)).astype(numpy.float32)
  79. vision_frame -= numpy.array([ 104, 117, 123 ]).astype(numpy.float32)
  80. vision_frame = numpy.expand_dims(vision_frame, axis = 0)
  81. return vision_frame
  82.  
  83.  
  84. @lru_cache(maxsize = None)
  85. def analyse_image(image_path : str) -> bool:
  86. frame = read_image(image_path)
  87. return analyse_frame(frame)
  88.  
  89.  
  90. @lru_cache(maxsize = None)
  91. def analyse_video(video_path : str, start_frame : int, end_frame : int) -> bool:
  92. video_frame_total = count_video_frame_total(video_path)
  93. video_fps = detect_video_fps(video_path)
  94. frame_range = range(start_frame or 0, end_frame or video_frame_total)
  95. rate = 0.0
  96. counter = 0
  97.  
  98. with tqdm(total = len(frame_range), desc = wording.get('analysing'), unit = 'frame', ascii = ' =', disable = facefusion.globals.log_level in [ 'warn', 'error' ]) as progress:
  99. for frame_number in frame_range:
  100. if frame_number % int(video_fps) == 0:
  101. frame = get_video_frame(video_path, frame_number)
  102. if analyse_frame(frame):
  103. counter += 1
  104. rate = counter * int(video_fps) / len(frame_range) * 100
  105. progress.update()
  106. progress.set_postfix(rate = rate)
  107. return rate > RATE_LIMIT
  108.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement