Guest User

unetsave

a guest
Aug 17th, 2024
54
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.94 KB | None | 0 0
  1. def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
  2. clip_sd = None
  3. vae_sd = None
  4. load_models = [model]
  5. if clip is not None:
  6. load_models.append(clip.load_model())
  7. clip_sd = clip.get_sd()
  8.  
  9. if vae is not None:
  10. vae_sd = vae.get_sd()
  11.  
  12. model_management.load_models_gpu(load_models, force_patch_weights=True)
  13. clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
  14. sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
  15. for k in extra_keys:
  16. sd[k] = extra_keys[k]
  17.  
  18. for k in sd:
  19. t = sd[k]
  20. if not t.is_contiguous():
  21. sd[k] = t.contiguous()
  22.  
  23. comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
  24.  
  25.  
  26. class UnetSave:
  27. def __init__(self):
  28. self.output_dir = folder_paths.get_output_directory()
  29.  
  30. @classmethod
  31. def INPUT_TYPES(s):
  32. return {"required": { "model": ("MODEL",),
  33. "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
  34. "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
  35. RETURN_TYPES = ()
  36. FUNCTION = "save"
  37. OUTPUT_NODE = True
  38.  
  39. CATEGORY = "advanced/model_merging"
  40.  
  41. def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
  42. save_checkpoint(model, clip=None, vae=None, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
  43. return {}
  44.  
  45. NODE_CLASS_MAPPINGS = {
  46. "ModelMergeSimple": ModelMergeSimple,
  47. "ModelMergeBlocks": ModelMergeBlocks,
  48. "ModelMergeSubtract": ModelSubtract,
  49. "ModelMergeAdd": ModelAdd,
  50. "CheckpointSave": CheckpointSave,
  51. "UnetSave": UnetSave,
  52. "CLIPMergeSimple": CLIPMergeSimple,
  53. "CLIPMergeSubtract": CLIPSubtract,
  54. "CLIPMergeAdd": CLIPAdd,
  55. "CLIPSave": CLIPSave,
  56. "VAESave": VAESave,
  57. }
Advertisement
Add Comment
Please, Sign In to add comment