Not a member of Pastebin yet? Sign up — it unlocks many cool features!
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys=None):
    """Build a combined state dict from the given components and write it to ``output_path``.

    Args:
        output_path: Destination forwarded unchanged to ``comfy.utils.save_torch_file``.
        model: Model wrapper. It is passed to ``model_management.load_models_gpu`` and its
            ``model.model.state_dict_for_saving(...)`` produces the state dict that is saved.
        clip: Optional CLIP object. When given, ``clip.load_model()`` is loaded together
            with ``model`` and ``clip.get_sd()`` is passed to ``state_dict_for_saving``.
        vae: Optional VAE object; ``vae.get_sd()`` is passed to ``state_dict_for_saving``.
        clip_vision: Optional CLIP-vision object; ``clip_vision.get_sd()`` is passed to
            ``state_dict_for_saving``.
        metadata: Forwarded unchanged to ``comfy.utils.save_torch_file``.
        extra_keys: Optional mapping of additional entries copied into the state dict,
            overwriting entries with the same key. ``None`` (the default) adds nothing.
            (Previously defaulted to a shared ``{}``; ``None`` avoids the mutable-default
            pitfall and is backward-compatible.)
    """
    load_models = [model]
    clip_sd = None
    if clip is not None:
        load_models.append(clip.load_model())
        clip_sd = clip.get_sd()
    vae_sd = vae.get_sd() if vae is not None else None

    # force_patch_weights=True: presumably ensures pending weight patches are baked
    # into the tensors before they are read for saving — TODO confirm in model_management.
    model_management.load_models_gpu(load_models, force_patch_weights=True)

    clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
    sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
    if extra_keys:
        sd.update(extra_keys)

    # Replace non-contiguous values with contiguous copies before serialization
    # (values assumed to be torch tensors — TODO confirm). Only existing keys are
    # reassigned, so the dict's size does not change during iteration.
    for key, tensor in sd.items():
        if not tensor.is_contiguous():
            sd[key] = tensor.contiguous()

    comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
class UnetSave:
    """Output node that passes a MODEL input to ``save_checkpoint`` without CLIP or VAE.

    The destination is derived from the ``filename_prefix`` input and ``self.output_dir``,
    which ``__init__`` reads from ``folder_paths.get_output_directory()``.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification (required and hidden inputs)."""
        required_inputs = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),
        }
        hidden_inputs = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required_inputs, "hidden": hidden_inputs}

    # Node metadata: no outputs, ``save`` is the entry point, and it is an output node.
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
        """Save ``model`` (with clip/vae explicitly ``None``) and return an empty result dict."""
        # NOTE(review): this call uses a save_checkpoint(model, ..., filename_prefix=,
        # output_dir=, prompt=, extra_pnginfo=) signature. The
        # save_checkpoint(output_path, model, ...) shown in this paste does not accept
        # these keywords and would raise TypeError, so this presumably relies on a
        # different module-level helper — confirm which save_checkpoint is in scope.
        save_kwargs = {
            "clip": None,
            "vae": None,
            "filename_prefix": filename_prefix,
            "output_dir": self.output_dir,
            "prompt": prompt,
            "extra_pnginfo": extra_pnginfo,
        }
        save_checkpoint(model, **save_kwargs)
        return {}
# Registry of node type names -> implementing classes, in registration order.
# Some registered names differ from their class names (e.g. "ModelMergeSubtract"
# -> ModelSubtract, "CLIPMergeAdd" -> CLIPAdd). Presumably consumed by the node
# loader elsewhere — the consumer is not visible here.
NODE_CLASS_MAPPINGS = dict(
    [
        # MODEL nodes
        ("ModelMergeSimple", ModelMergeSimple),
        ("ModelMergeBlocks", ModelMergeBlocks),
        ("ModelMergeSubtract", ModelSubtract),
        ("ModelMergeAdd", ModelAdd),
        ("CheckpointSave", CheckpointSave),
        ("UnetSave", UnetSave),
        # CLIP nodes
        ("CLIPMergeSimple", CLIPMergeSimple),
        ("CLIPMergeSubtract", CLIPSubtract),
        ("CLIPMergeAdd", CLIPAdd),
        ("CLIPSave", CLIPSave),
        # VAE nodes
        ("VAESave", VAESave),
    ]
)
Advertisement
Add Comment
Please sign in to add a comment.