mmquant

nodes_torch_compile_lora_safe.py

Jun 4th, 2025 (edited)
# Torch 2.3-ready "LoRA-safe" compile node for ComfyUI
# ---------------------------------------------------
# – Fixes accidental `mode=None`
# – Ensures .train() / .eval() are forwarded to the real UNet
# – Handles concurrent first calls safely with a lock
# – Performs a quick backend sanity check so users get friendly errors
#
# Drop this file in your ComfyUI/custom_nodes folder and restart.

import threading

import torch
import torch.nn as nn


# --------------------------------------------------------------------- #
#  Helper – transparent wrapper that compiles itself on first forward()
# --------------------------------------------------------------------- #
class _LazyCompiled(nn.Module):
    """
    Wraps any nn.Module (e.g. a UNet or a single transformer block) and
    replaces its forward pass with torch.compile **lazily** on the first
    call.  All other attributes / helpers are proxied so that the module
    behaves like the original one.
    """
    def __init__(self, module: nn.Module, **compile_kw):
        super().__init__()
        self._orig         = module
        self._compiled     = None
        self._compile_kw   = compile_kw
        self._compile_lock = threading.Lock()   # avoid double-compile races

    # ---------- Attribute & module proxies -------------------------------- #
    def __getattr__(self, name):
        # nn.Module registers `_orig` (and, after first call, `_compiled`)
        # in self._modules, so plain attribute lookup misses them and lands
        # here; route our bookkeeping names back to nn.Module.
        if name in {"_orig", "_compiled", "_compile_kw", "_compile_lock"}:
            return super().__getattr__(name)
        return getattr(self._orig, name)        # delegate to real module

    def modules(self):                return self._orig.modules()
    def children(self):               return self._orig.children()
    def named_modules(self, *a, **k): return self._orig.named_modules(*a, **k)
    def state_dict(self, *a, **k):    return self._orig.state_dict(*a, **k)

    # dtype / device queries that ComfyUI calls during sampler setup
    @property
    def dtype(self):  return self._orig.dtype
    @property
    def device(self): return self._orig.device

    # ---------- Training / device helpers --------------------------------- #
    def train(self, mode: bool = True):
        # keep both wrapper *and* wrapped module in sync
        self._orig.train(mode)
        return super().train(mode)

    def eval(self):         # convenience alias
        return self.train(False)

    def to(self, *args, **kwargs):
        self._orig.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    # ---------- first call → compile -------------------------------------- #
    def forward(self, *args, **kwargs):
        if self._compiled is None:
            # double-checked locking: only one thread actually compiles
            with self._compile_lock:
                if self._compiled is None:
                    self._compiled = torch.compile(self._orig, **self._compile_kw)
        return self._compiled(*args, **kwargs)

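# Typical standalone use (a sketch – `unet` and its inputs are placeholders;
# inside ComfyUI the node below applies the wrapper for you):
#     lazy = _LazyCompiled(unet, backend="inductor", fullgraph=False)
#     out  = lazy(x, timesteps, context)   # first call compiles, later calls reuse
# A runnable smoke test lives at the bottom of this file.
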
# --------------------------------------------------------------------- #
#                       ComfyUI node definition
# --------------------------------------------------------------------- #
class TorchCompileModel_LoRASafe:
    """LoRA-safe torch.compile with extra options."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),

                # same four knobs as the stock node
                "backend": (["inductor", "cudagraphs", "nvfuser"],),
                "mode":    (["default", "reduce-overhead", "max-autotune"],),
                "fullgraph": ("BOOLEAN", {"default": False}),
                "dynamic":   ("BOOLEAN", {"default": False}),

                # replicate compile_transformer_block_only
                "compile_transformer_only": (
                    "BOOLEAN",
                    {"default": False,
                     "tooltip":
                     "True → compile each transformer block lazily; "
                     "False → compile whole UNet once"}
                ),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION     = "patch"
    CATEGORY     = "model/optimisation 🛠️"
    EXPERIMENTAL = True

    # ----------------------------------------------------------------- #
    @staticmethod
    def _check_backend(backend: str):
        """Raise a friendly error if the chosen backend cannot run."""
        # Some choices (e.g. nvfuser) may no longer ship with the installed
        # PyTorch; checking the registry gives a readable error instead of
        # a raw dynamo traceback.
        if backend not in torch._dynamo.list_backends():
            raise ValueError(f"torch.compile backend '{backend}' is not "
                             "available in this PyTorch build.")
        if backend == "nvfuser" and not torch.cuda.is_available():
            raise ValueError("nvfuser backend requires a CUDA GPU. "
                             "Select 'inductor' instead.")
        if backend == "cudagraphs":
            if not torch.cuda.is_available():
                raise ValueError("cudagraphs backend needs a CUDA GPU.")
            cap = torch.cuda.get_device_capability()
            if cap[0] < 7:
                raise ValueError("cudagraphs works reliably on GPUs with "
                                 "compute capability 7.0 or higher "
                                 f"(yours is {cap[0]}.{cap[1]}).")

    # ----------------------------------------------------------------- #
    def patch(self,
              model, backend, mode,
              fullgraph, dynamic,
              compile_transformer_only):

        # backend sanity check before we go any further
        self._check_backend(backend)

        m  = model.clone()                              # don't mutate input
        dm = m.get_model_object("diffusion_model")      # real UNet

        # build torch.compile() kwargs; only forward `mode` when it is a
        # real value, so we never pass the stock node's accidental mode=None
        compile_kw = dict(
            backend   = backend,
            fullgraph = fullgraph,
            dynamic   = dynamic,
        )
        if mode != "default":
            compile_kw["mode"] = mode

        # ---------- A) whole-UNet compile (default) ------------------- #
        if not compile_transformer_only:
            m.add_object_patch(
                "diffusion_model",
                _LazyCompiled(dm, **compile_kw)
            )
            return (m,)

        # ---------- B) transformer-only compile ----------------------- #
        if hasattr(dm, "transformer_blocks"):
            for i, blk in enumerate(dm.transformer_blocks):
                m.add_object_patch(
                    f"diffusion_model.transformer_blocks.{i}",
                    _LazyCompiled(blk, **compile_kw)
                )
        else:  # fallback – compile the whole UNet
            m.add_object_patch(
                "diffusion_model",
                _LazyCompiled(dm, **compile_kw)
            )
        return (m,)


# --------------------------------------------------------------------- #
#                     ComfyUI registration shim
# --------------------------------------------------------------------- #
NODE_CLASS_MAPPINGS = {
    "TorchCompileModel_LoRASafe": TorchCompileModel_LoRASafe,
}

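# Optional, but conventional for ComfyUI custom nodes: a human-readable
# name for the node search menu.
NODE_DISPLAY_NAME_MAPPINGS = {
    "TorchCompileModel_LoRASafe": "TorchCompileModel (LoRA-safe)",
}


# --------------------------------------------------------------------- #
#  Standalone smoke test – a minimal sketch, not used by ComfyUI.
#  The "eager" backend exercises the lazy-compile path without needing a
#  C++/Triton toolchain.  Run with:  python nodes_torch_compile_lora_safe.py
# --------------------------------------------------------------------- #
if __name__ == "__main__":
    toy  = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
    lazy = _LazyCompiled(toy, backend="eager")

    x  = torch.randn(4, 8)
    y1 = lazy(x)                         # first call triggers torch.compile
    first = lazy._compiled
    y2 = lazy(x)                         # second call reuses the cached module
    assert lazy._compiled is first, "module was recompiled"
    assert torch.allclose(y1, y2) and torch.allclose(y1, toy(x))

    lazy.eval()                          # train()/eval() reach the real module
    assert toy.training is False
    print("_LazyCompiled smoke test passed.")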