Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
- import os
- from safetensors import safe_open
- from safetensors.torch import save_file
class UnetExtract:
    """ComfyUI node that copies the UNet tensors out of a checkpoint file.

    Every tensor whose name contains ``diffusion_model`` is read from the
    selected ``.safetensors`` checkpoint and written to a file of the same
    base name under ``models/unet``. The node's purpose is that file-writing
    side effect; its only output is a human-readable status message.
    """

    @staticmethod
    def _models_dir(subdir):
        """Return ``<dir of this file>/../models/<subdir>``."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(base_dir, "..", "models", subdir)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs.

        ``safetensors_file`` is a dropdown of every ``.safetensors`` file found
        recursively under ``models/checkpoints`` (full paths, sorted so the
        list order is stable between runs). ``overwrite`` decides whether an
        existing file in ``models/unet`` may be replaced.
        """
        checkpoints_dir = cls._models_dir("checkpoints")
        # os.walk yields nothing for a missing directory, giving an empty list.
        safetensors_files = sorted(
            os.path.join(root, name)
            for root, _dirs, files in os.walk(checkpoints_dir)
            for name in files
            if name.endswith(".safetensors")
        )
        return {
            "required": {
                "safetensors_file": (safetensors_files,),  # list of choices
            },
            "optional": {
                "overwrite": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("message",)
    CATEGORY = "UNET"
    FUNCTION = "extract_unet"
    # This node exists for its side effect (writing a file). ComfyUI only
    # executes nodes that lead to an output node, so mark it as one.
    OUTPUT_NODE = True

    def extract_unet(self, safetensors_file, overwrite=False):
        """Extract the ``diffusion_model`` tensors into ``models/unet``.

        Args:
            safetensors_file: Path of the source checkpoint (one of the
                choices offered by ``INPUT_TYPES``).
            overwrite: Replace an existing file of the same name in
                ``models/unet``. Defaults to ``False``, keeping existing files.

        Returns:
            A one-element tuple holding a status message, as required by
            ``RETURN_TYPES``.
        """
        file_path = safetensors_file
        unet_dir = self._models_dir("unet")
        unet_path = os.path.join(unet_dir, os.path.basename(file_path))

        if not os.path.exists(file_path):
            return (f"File {file_path} does not exist.",)

        # The node runs inside the ComfyUI server, where blocking on input()
        # stalls execution (or fails on a closed stdin); the overwrite
        # decision is therefore a node input instead of a console prompt.
        if os.path.exists(unet_path) and not overwrite:
            return (
                f"File {unet_path} already exists and was not overwritten; "
                "enable 'overwrite' to replace it.",
            )

        with safe_open(file_path, framework="pt", device="cpu") as f:
            diffusion_tensors = {
                name: f.get_tensor(name)
                for name in f.keys()
                if "diffusion_model" in name
            }

        if not diffusion_tensors:
            return (f"No tensors with 'diffusion_model' found in {file_path}.",)

        # models/unet may not exist yet on a fresh install; save_file is not
        # expected to create missing parent directories.
        os.makedirs(unet_dir, exist_ok=True)
        save_file(diffusion_tensors, unet_path)
        return (f"Tensors with 'diffusion_model' saved to {unet_path}",)
# Node registry for this module: maps the node type name "UnetExtract" to
# its implementing class (presumably read by ComfyUI's custom-node loader).
NODE_CLASS_MAPPINGS = dict(UnetExtract=UnetExtract)
Advertisement
Add Comment
Please sign in to add a comment.