Advertisement
Guest User

loratags.py

a guest
Jun 17th, 2023
333
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.60 KB | None | 0 0
import json
import re
import sys

from argparse import ArgumentParser
from collections import Counter
from pathlib import Path
  6.  
# Matches a "(" or ")" that is not immediately preceded by a backslash,
# i.e. a parenthesis that has not been escaped yet.
NONESCPAREN_REGEX = re.compile(r'(?<!\\)[()]')
  8.  
  9. def parse_args():
  10.     parser = ArgumentParser(description="get a prompt from a lora's built-in tags")
  11.     parser.add_argument("lorafile", type=Path, help="path to the lora file in safetensors format")
  12.     parser.add_argument("-countmin", type=int, default=10, help="minimum tag count for a tag to be included")
  13.     parser.add_argument("-tagmax", type=int, default=30, help="maximum amount of tags to include in the prompt")
  14.     return parser.parse_args()
  15.  
  16. def read_metadata_from_safetensors(filename) -> dict:
  17.     with open(filename, mode="rb") as file:
  18.         metadata_len = file.read(8)
  19.         metadata_len = int.from_bytes(metadata_len, "little")
  20.         json_start = file.read(2)
  21.  
  22.         assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
  23.  
  24.         json_data = json_start + file.read(metadata_len-2)
  25.         json_obj = json.loads(json_data)
  26.  
  27.         res = {}
  28.         for k, v in json_obj.get("__metadata__", {}).items():
  29.             res[k] = v
  30.             if isinstance(v, str) and v[0:1] == "{":
  31.                 try:
  32.                     res[k] = json.loads(v)
  33.                 except Exception:
  34.                     pass
  35.         return res
  36.  
  37. def main():
  38.     args = parse_args()
  39.     lorafile = args.lorafile.resolve()
  40.     if not lorafile.exists():
  41.         print(f"'{lorafile.name}' does not exist")
  42.         exit(1)
  43.     if not lorafile.is_file():
  44.         print(f"'{lorafile.name}' is not a file")
  45.         exit(1)
  46.  
  47.     try:
  48.         metadata = read_metadata_from_safetensors(lorafile)
  49.     except Exception as e:
  50.         print(e)
  51.         exit(1)
  52.  
  53.     tag_frequency: dict = metadata.get("ss_tag_frequency")
  54.     if tag_frequency is None:
  55.         print("lora has no tag metadata.")
  56.         exit(1)
  57.  
  58.     # In the case of multiple groups, join all their tags together
  59.     all_tags = {}
  60.     for tags in tag_frequency.values():
  61.         tags: dict
  62.         if not all_tags:
  63.             all_tags = tags.copy()
  64.         else:
  65.             for tag, count in tags.items():
  66.                 try:
  67.                     all_tags[tag] += count
  68.                 except KeyError:
  69.                     all_tags[tag] = count
  70.  
  71.     tags = sorted(all_tags.items(), key=lambda item: item[1], reverse=True)
  72.     tags = list(filter(lambda item: item[1] >= args.countmin, tags))
  73.     tags = [NONESCPAREN_REGEX.sub(r'\\\g<0>', item[0]).strip() for item in tags[0:args.tagmax]]
  74.  
  75.     print(', '.join(tags))
  76.  
  77.  
  78. if __name__ == "__main__":
  79.     main()
  80.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement