import torch
from onediff.infer_compiler import compile
import os


class TorchCompileModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "backend": (["inductor", "cudagraphs"],),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model, backend):
        # os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'  # may not be needed
        # os.environ['NEXFORT_FUSE_TIMESTEP_EMBEDDING'] = '0'  # may not be needed
        # options = '{"mode": "O3"}'  # mode can be O2 or O3
        options = {
            "mode": "max-optimize:max-autotune:max-autotune:cache-all",
            "options": {
                "inductor.optimize_linear_epilogue": True,
                "triton.fuse_attention_allow_fp16_reduction": True,
            },
        }
        m = model.clone()
        # Original ComfyUI path, which compiles with the selected torch.compile backend:
        # m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
        # Compile the diffusion model with onediff's nexfort backend instead;
        # the "backend" input above is effectively unused on this code path.
        m.add_object_patch("diffusion_model", compile(m.get_model_object("diffusion_model"), backend="nexfort", options=options))
        return (m,)


NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
}
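
# Optional sketch: ComfyUI also reads NODE_DISPLAY_NAME_MAPPINGS from custom node
# modules to label nodes in the UI. The display string below is only an assumed
# name, not part of the node's required interface.
NODE_DISPLAY_NAME_MAPPINGS = {
    "TorchCompileModel": "Torch Compile Model (nexfort)",
}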