Guest User

Untitled

a guest
Dec 16th, 2024
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.70 KB | None | 0 0
  1. """Utility functions for the tagger module"""
  2. import os
  3.  
  4. from typing import List, Dict
  5. from pathlib import Path
  6.  
  7. from modules import shared, scripts  # pylint: disable=import-error
  8. from modules.shared import models_path  # pylint: disable=import-error
  9.  
  10. default_ddp_path = Path(models_path, 'deepdanbooru')
  11. default_onnx_path = Path(models_path, 'TaggerOnnx')
  12. from tagger.preset import Preset  # pylint: disable=import-error
  13. from tagger.interrogator import Interrogator, DeepDanbooruInterrogator, \
  14.                                 MLDanbooruInterrogator  # pylint: disable=E0401 # noqa: E501
  15. from tagger.interrogator import WaifuDiffusionInterrogator  # pylint: disable=E0401 # noqa: E501
  16.  
  17. preset = Preset(Path(scripts.basedir(), 'presets'))
  18.  
  19. interrogators: Dict[str, Interrogator] = {
  20.     'wd14-vit.v1': WaifuDiffusionInterrogator(
  21.         'WD14 ViT v1',
  22.         repo_id='SmilingWolf/wd-v1-4-vit-tagger'
  23.     ),
  24.     'wd14-vit.v2': WaifuDiffusionInterrogator(
  25.         'WD14 ViT v2',
  26.         repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2',
  27.     ),
  28.     'wd14-convnext.v1': WaifuDiffusionInterrogator(
  29.         'WD14 ConvNeXT v1',
  30.         repo_id='SmilingWolf/wd-v1-4-convnext-tagger'
  31.     ),
  32.     'wd14-convnext.v2': WaifuDiffusionInterrogator(
  33.         'WD14 ConvNeXT v2',
  34.         repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2',
  35.     ),
  36.     'wd14-convnextv2.v1': WaifuDiffusionInterrogator(
  37.         'WD14 ConvNeXTV2 v1',
  38.         # the name is misleading, but it's v1
  39.         repo_id='SmilingWolf/wd-v1-4-convnextv2-tagger-v2',
  40.     ),
  41.     'wd14-swinv2-v1': WaifuDiffusionInterrogator(
  42.         'WD14 SwinV2 v1',
  43.         # again misleading name
  44.         repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2',
  45.     ),
  46.     'wd-v1-4-moat-tagger.v2': WaifuDiffusionInterrogator(
  47.         'WD14 moat tagger v2',
  48.         repo_id='SmilingWolf/wd-v1-4-moat-tagger-v2'
  49.     ),
  50.     'wd-v1-4-vit-tagger.v3': WaifuDiffusionInterrogator(
  51.         'WD14 ViT v3',
  52.         repo_id='SmilingWolf/wd-vit-tagger-v3'
  53.     ),
  54.     'wd-v1-4-convnext-tagger.v3': WaifuDiffusionInterrogator(
  55.         'WD14 ConvNext v3',
  56.         repo_id='SmilingWolf/wd-convnext-tagger-v3'
  57.     ),
  58.     'wd-v1-4-swinv2-tagger.v3': WaifuDiffusionInterrogator(
  59.         'WD14 SwinV2 v3',
  60.         repo_id='SmilingWolf/wd-swinv2-tagger-v3'
  61.     ),
  62.     'wd-vit-large-tagger.v3': WaifuDiffusionInterrogator(
  63.         'WD ViT-Large Tagger v3',
  64.         repo_id='SmilingWolf/wd-vit-large-tagger-v3'
  65.     ),
  66.     'wd-eva02-large-tagger-v3': WaifuDiffusionInterrogator(
  67.         'WD EVA02-Large Tagger v3',
  68.         repo_id='SmilingWolf/wd-eva02-large-tagger-v3'
  69.     ),
  70.     'mld-caformer.dec-5-97527': MLDanbooruInterrogator(
  71.         'ML-Danbooru Caformer dec-5-97527',
  72.         repo_id='deepghs/ml-danbooru-onnx',
  73.         model_path='ml_caformer_m36_dec-5-97527.onnx'
  74.     ),
  75.     'mld-tresnetd.6-30000': MLDanbooruInterrogator(
  76.         'ML-Danbooru TResNet-D 6-30000',
  77.         repo_id='deepghs/ml-danbooru-onnx',
  78.         model_path='TResnet-D-FLq_ema_6-30000.onnx'
  79.     ),
  80. }
  81.  
  82.  
  83. def refresh_interrogators() -> List[str]:
  84.     """Refreshes the interrogators list"""
  85.     # load deepdanbooru project
  86.     ddp_path = shared.cmd_opts.deepdanbooru_projects_path
  87.     if ddp_path is None:
  88.         ddp_path = default_ddp_path
  89.     onnx_path = shared.cmd_opts.onnxtagger_path
  90.     if onnx_path is None:
  91.         onnx_path = default_onnx_path
  92.     os.makedirs(ddp_path, exist_ok=True)
  93.     os.makedirs(onnx_path, exist_ok=True)
  94.  
  95.     for path in os.scandir(ddp_path):
  96.         print(f"Scanning {path} as deepdanbooru project")
  97.         if not path.is_dir():
  98.             print(f"Warning: {path} is not a directory, skipped")
  99.             continue
  100.  
  101.         if not Path(path, 'project.json').is_file():
  102.             print(f"Warning: {path} has no project.json, skipped")
  103.             continue
  104.  
  105.         interrogators[path.name] = DeepDanbooruInterrogator(path.name, path)
  106.     # scan for onnx models as well
  107.     for path in os.scandir(onnx_path):
  108.         print(f"Scanning {path} as onnx model")
  109.         if not path.is_dir():
  110.             print(f"Warning: {path} is not a directory, skipped")
  111.             continue
  112.  
  113.         onnx_files = [x for x in os.scandir(path) if x.name.endswith('.onnx')]
  114.         if len(onnx_files) != 1:
  115.             print(f"Warning: {path} requires exactly one .onnx model, skipped")
  116.             continue
  117.         local_path = Path(path, onnx_files[0].name)
  118.  
  119.         csv = [x for x in os.scandir(path) if x.name.endswith('.csv')]
  120.         if len(csv) == 0:
  121.             print(f"Warning: {path} has no selected tags .csv file, skipped")
  122.             continue
  123.  
  124.         def tag_select_csvs_up_front(k):
  125.             sum(-1 if t in k.name.lower() else 1 for t in ["tag", "select"])
  126.  
  127.         csv.sort(key=tag_select_csvs_up_front)
  128.         tags_path = Path(path, csv[0])
  129.  
  130.         if path.name not in interrogators:
  131.             if path.name == 'wd-v1-4-convnextv2-tagger-v2':
  132.                 interrogators[path.name] = WaifuDiffusionInterrogator(
  133.                     path.name,
  134.                     repo_id='SmilingWolf/SW-CV-ModelZoo',
  135.                     is_hf=False
  136.                 )
  137.             elif path.name == 'Z3D-E621-Convnext':
  138.                 interrogators[path.name] = WaifuDiffusionInterrogator(
  139.                     'Z3D-E621-Convnext', is_hf=False)
  140.             else:
  141.                 raise NotImplementedError(f"Add {path.name} resolution similar"
  142.                                           "to above here")
  143.  
  144.         interrogators[path.name].local_model = str(local_path)
  145.         interrogators[path.name].local_tags = str(tags_path)
  146.  
  147.     return sorted(interrogators.keys())
  148.  
  149.  
  150. def split_str(string: str, separator=',') -> List[str]:
  151.     return [x.strip() for x in string.split(separator) if x]
  152.  
Advertisement
Add Comment
Please, Sign In to add comment