node

a guest
Oct 12th, 2024
76
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.21 KB | None | 0 0
# ComfyUI custom node: wraps the loaded model's diffusion_model with
# onediff's compile() using the nexfort backend.
import torch  # only needed for the commented-out torch.compile variant below
import os     # only needed if the NEXFORT_* environment variables are uncommented
from onediff.infer_compiler import compile


class TorchCompileModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "backend": (["inductor", "cudagraphs"],),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model, backend):
        # Note: the "backend" input is currently ignored; the call below
        # hardcodes onediff's "nexfort" backend.
        # os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'    # may not be needed
        # os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0' # may not be needed
        # options = '{"mode": "O3"}'  # mode can be O2 or O3
        options = {"mode": "max-optimize:max-autotune:max-autotune:cache-all",
                   "options": {"inductor.optimize_linear_epilogue": True,
                               "triton.fuse_attention_allow_fp16_reduction": True}}
        # Clone the ModelPatcher so the original model stays untouched, then
        # patch in a compiled copy of its diffusion_model.
        m = model.clone()
        # m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
        m.add_object_patch("diffusion_model",
                           compile(m.get_model_object("diffusion_model"),
                                   backend="nexfort", options=options))
        return (m,)


NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
}
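
A minimal usage sketch, assuming the standard ComfyUI custom-node setup: save this file under ComfyUI/custom_nodes/ (the file name below is hypothetical) and ComfyUI reads NODE_CLASS_MAPPINGS from it on startup. Optionally, a NODE_DISPLAY_NAME_MAPPINGS dict can be appended to give the node a friendlier label in the node menu; the label string here is only an example.

# Hypothetical location: ComfyUI/custom_nodes/torch_compile_nexfort.py
# (file name and label below are placeholders)
# Optional: friendlier label for the node in ComfyUI's "Add Node" menu.
NODE_DISPLAY_NAME_MAPPINGS = {
    "TorchCompileModel": "Torch Compile Model (nexfort)",
}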