Guest User

ComfyUI UNet Extractor

a guest
Aug 17th, 2024
223
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
Python 2.02 KB | Software | 0 0
  1. import os
  2. from safetensors import safe_open
  3. from safetensors.torch import save_file
  4.  
  5. class UnetExtract:
  6.     @classmethod
  7.     def INPUT_TYPES(cls):
  8.         base_dir = os.path.dirname(os.path.abspath(__file__))
  9.         checkpoints_dir = os.path.join(base_dir, "..", "models", "checkpoints")
  10.        
  11.         safetensors_files = []
  12.         for root, dirs, files in os.walk(checkpoints_dir):
  13.             safetensors_files.extend([os.path.join(root, f) for f in files if f.endswith('.safetensors')])
  14.        
  15.         return {
  16.             "required": {
  17.                 "safetensors_file": (safetensors_files, )  # Use a list of choices directly
  18.             }
  19.         }
  20.  
  21.     RETURN_TYPES = ("STRING",)
  22.     RETURN_NAMES = ("message",)
  23.     CATEGORY = "UNET"
  24.     FUNCTION = "extract_unet"
  25.  
  26.     def extract_unet(self, safetensors_file):
  27.         base_dir = os.path.dirname(os.path.abspath(__file__))
  28.         unet_dir = os.path.join(base_dir, "..", "models", "unet")
  29.  
  30.         file_path = safetensors_file
  31.         unet_path = os.path.join(unet_dir, os.path.basename(safetensors_file))
  32.  
  33.         if not os.path.exists(file_path):
  34.             return f"File {file_path} does not exist."
  35.  
  36.         if os.path.exists(unet_path):
  37.             print(f"File {unet_path} already exists. Overwrite? (yes/no): ")
  38.             user_input = input()
  39.             if user_input.lower() != 'yes':
  40.                 return f"File {unet_path} was not overwritten. Check the console for details."
  41.  
  42.         with safe_open(file_path, framework="pt", device="cpu") as f:
  43.             tensor_names = f.keys()
  44.             diffusion_tensors = {name: f.get_tensor(name) for name in tensor_names if "diffusion_model" in name}
  45.            
  46.             if not diffusion_tensors:
  47.                 return f"No tensors with 'diffusion_model' found in {file_path}."
  48.            
  49.             save_file(diffusion_tensors, unet_path)
  50.        
  51.         return f"Tensors with 'diffusion_model' saved to {unet_path}"
  52.  
  53. NODE_CLASS_MAPPINGS = {
  54.     "UnetExtract": UnetExtract,
  55. }
  56.  
Tags: ComfyUI
Advertisement
Add Comment
Please sign in to add a comment.