#!/usr/bin/env bash
# === ComfyUI WanVideoWrapper — Unified MultiTalk Patcher v3 ==================
# Purpose:
#   Enable MultiTalk on Wan 2.2 I2V 14B GGUF exports by grafting an audio
#   projection + audio cross-attention path at model load time, and make
#   wav2vec ingestion robust. Safetensors builds are unaffected.
#
# Key fixes:
#   • Provide/attach audio_proj + audio_cross_attn if missing (GGUF gap).
#   • Apply grafts in BOTH single-GPU and Multi-GPU loaders.
#   • Harden the wav2vec path (seq_len signature diffs, optional torchaudio,
#     unify hidden_states and resample to video_length).
#
# Safety:
#   • Idempotent; backups are written once as *.mtbak_unified_v3_before.
#   • Clear logs on graft execution to simplify troubleshooting.
# ============================================================================
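#
# Usage sketch (the filename and /opt paths below are illustrative, not part of
# the patcher itself): run as-is for the default /workspace layout, or override
# the paths via environment variables, e.g.:
#   WRAP_DIR=/opt/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper \
#   MGPU_DIR=/opt/ComfyUI/custom_nodes/comfyui-multigpu \
#   bash multitalk_unified_patcher_v3.sh
# With no overrides, the defaults under /workspace/ComfyUI are used.
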
set -euo pipefail  # Fail fast; treat unset vars as errors

# --- Paths (override via env if your layout differs) -------------------------
WRAP_DIR="${WRAP_DIR:-/workspace/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper}"
MGPU_DIR="${MGPU_DIR:-/workspace/ComfyUI/custom_nodes/comfyui-multigpu}"
MOD_DIR="$WRAP_DIR/wanvideo/modules"            # Where the shim will live
MULTITALK_DIR="$WRAP_DIR/multitalk"             # MultiTalk nodes location
LOADER_PY="$WRAP_DIR/nodes_model_loading.py"    # Single-GPU loader
MGPU_WANVIDEO_PY="$MGPU_DIR/wanvideo.py"        # Multi-GPU loader (if present)
NODES_PY="$MULTITALK_DIR/nodes.py"              # MultiTalk node definitions
AP_SHIM="$MOD_DIR/audio_proj_shim.py"           # Shim target path

# --- Sanity checks -----------------------------------------------------------
command -v python3 >/dev/null || { echo "python3 is required"; exit 1; }
test -d "$WRAP_DIR" || { echo "WanVideoWrapper not found: $WRAP_DIR"; exit 1; }
test -f "$LOADER_PY" || { echo "Missing: $LOADER_PY"; exit 1; }
test -f "$NODES_PY" || { echo "Missing: $NODES_PY"; exit 1; }
mkdir -p "$MOD_DIR" "$MULTITALK_DIR"

# --- [1/4] Write/update the runtime shim that performs the graft -------------
# The shim:
#   • Unwraps the underlying Wan model from the patcher.
#   • Ensures 'audio_proj' exists (uses pack proj if provided; Identity otherwise).
#   • Adds 'audio_cross_attn' and 'norm_x' per block if missing.
#   • Loads any provided state dict softly and aligns device/dtype.
#   • Marks _mt_grafted to avoid repeated work across reloads.
echo "[1/4] Ensure shim: $AP_SHIM"
AP_SHIM="$AP_SHIM" python3 - <<'PY'
import os, io

dst = os.environ.get("AP_SHIM", "/workspace/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper/wanvideo/modules/audio_proj_shim.py")

src = r'''# Auto-generated by unified MultiTalk patcher v3
# Role: Provide a late-binding graft for MultiTalk audio ingestion when
#       Wan GGUF exports lack audio_proj + audio_cross_attn.
import torch
from torch import nn

from ...multitalk.multitalk import SingleStreamMultiAttention
from .model import WanLayerNorm, WanRMSNorm

__all__ = ["setup_multitalk_for_model"]


def _unwrap_model(patcher_or_model):
    """Locate the underlying Wan model regardless of wrapper/patcher type."""
    m = getattr(patcher_or_model, "model", None)
    if m is None and hasattr(patcher_or_model, "diffusion_model"):
        m = patcher_or_model
    if m is None:
        for k in ("_model", "_unet"):
            m = getattr(patcher_or_model, k, None)
            if m is not None:
                break
    if m is None:
        raise RuntimeError("Could not unwrap model from patcher")
    return m


def setup_multitalk_for_model(patcher, multitalk_model=None):
    """Attach audio projection + cross-attn only if missing (idempotent)."""
    model = _unwrap_model(patcher)

    # Skip if we already grafted this instance
    if getattr(model, "_mt_grafted", False):
        return patcher

    # If blocks already have audio cross-attn, assume prior integration
    try:
        if hasattr(model.diffusion_model.blocks[0], "audio_cross_attn"):
            model._mt_grafted = True
            return patcher
    except Exception:
        pass

    # Need a MultiTalk pack to source dimensions/weights when available
    if multitalk_model is None:
        return patcher

    # Prefer a projection supplied by the pack (proj_model); fall back to Identity
    proj = multitalk_model.get("proj_model", None) if isinstance(multitalk_model, dict) else None
    if not hasattr(model, "audio_proj"):
        if proj is None:
            # Keep it conservative; dimension discovery happens downstream
            proj = nn.Identity()
        model.audio_proj = proj

    # Attach per-block cross-attn + norm if absent
    try:
        blocks = model.diffusion_model.blocks
    except Exception as e:
        raise RuntimeError(f"[MultiTalk] Unexpected WanModel structure: {e}")

    for blk in blocks:
        dim = getattr(blk, "dim", None) or getattr(blk, "hidden_size", None)
        if dim is None:
            raise RuntimeError("[MultiTalk] block.dim missing")
        num_heads = getattr(blk, "num_heads", 8)
        eps = getattr(blk, "eps", 1e-6)
        attention_mode = getattr(blk, "attention_mode", "sdpa")
        if not hasattr(blk, "audio_cross_attn"):
            blk.audio_cross_attn = SingleStreamMultiAttention(
                dim=dim, encoder_hidden_states_dim=768, num_heads=num_heads,
                qk_norm=False, qkv_bias=True, eps=eps, norm_layer=WanRMSNorm,
                attention_mode=attention_mode,
            )
        if not hasattr(blk, "norm_x"):
            blk.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True)

    # Best-effort soft load of any provided state dict
    try:
        sd = multitalk_model.get("sd", {}) if isinstance(multitalk_model, dict) else {}
        model.load_state_dict(sd, strict=False)
    except Exception:
        pass

    # Match device/dtype of the main model params
    try:
        ref = next(model.diffusion_model.parameters())
        model.to(ref.device, dtype=ref.dtype)
    except Exception:
        pass

    model._mt_grafted = True
    print("[MultiTalk] Grafted audio_proj + audio_cross_attn at load time (shim).")
    return patcher
'''

os.makedirs(os.path.dirname(dst), exist_ok=True)
old = io.open(dst, "r", encoding="utf-8", errors="ignore").read() if os.path.exists(dst) else None
if old != src:
    io.open(dst, "w", encoding="utf-8").write(src)
    print(f"[shim] wrote {dst}")
else:
    print(f"[shim] up-to-date {dst}")
PY
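
# Optional sanity check (a minimal sketch, assuming the shim was just written
# above): confirm the generated file parses before touching any loaders.
# Non-fatal on failure.
python3 -c "import io; compile(io.open('$AP_SHIM', encoding='utf-8').read(), '$AP_SHIM', 'exec'); print('[check] shim parses OK')" || true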

# --- [2/4] Patch the single-GPU loader to call the shim at the right time ----
# Also includes minor hygiene (duplicate class collapse / INPUT_TYPES) and
# injects the graft just before the final return in loadmodel(...).
echo "[2/4] Patch WanVideoWrapper loader + LoRA class (nodes_model_loading.py)"
LOADER_PY="$LOADER_PY" python3 - <<'PY'
import io, re, sys, os

PYFILE = os.environ.get("LOADER_PY", "/workspace/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper/nodes_model_loading.py")
src = io.open(PYFILE, "r", encoding="utf-8", errors="ignore").read()
io.open(PYFILE + ".mtbak_unified_v3_before", "w", encoding="utf-8").write(src)
print("[backup]", PYFILE + ".mtbak_unified_v3_before")

changed = 0

# Ensure we can import our shim locally
if "setup_multitalk_for_model" not in src:
    ins = "from .wanvideo.modules.audio_proj_shim import setup_multitalk_for_model\n"
    m = re.search(r'(?m)^from\s+\.utils\s+import\s+log\s*$', src)
    if m:
        src = src[:m.end()] + "\n" + ins + src[m.end():]
    else:
        top = re.search(r'(?m)^(?:from|import)\s+[^\n]+\n(?:\s*(?:from|import)\s+[^\n]+\n)*', src)
        src = src[:top.end()] + ins + src[top.end():] if top else ins + src
    changed += 1
    print("[import] added loader shim import")

# Utilities to find / edit class blocks
def find_class_block(code, name):
    m = re.search(r'(?m)^class\s+' + re.escape(name) + r'\s*:\s*$', code)
    if not m:
        return None
    s = m.start(); tail = code[m.end():]
    n = re.search(r'(?m)^(class\s+\w+|#region\b|NODE_CLASS_MAPPINGS\s*=)', tail)
    e = m.end() + (n.start() if n else len(tail))
    return (s, e)

def class_body_indent(txt):
    for L in txt.splitlines()[1:]:
        if L.strip() and not L.lstrip().startswith("#"):
            return re.match(r'^(\s*)', L).group(1)
    return "    "

# Collapse duplicate WanVideoSetLoRAs definitions (defensive)
blocks = [(m.start(), None) for m in re.finditer(r'(?m)^class\s+WanVideoSetLoRAs\s*:\s*$', src)]
if len(blocks) > 1:
    positions = []
    for m in re.finditer(r'(?m)^class\s+WanVideoSetLoRAs\s*:\s*$', src):
        s = m.start(); tail = src[m.end():]
        n = re.search(r'(?m)^(class\s+\w+|#region\b|NODE_CLASS_MAPPINGS\s*=)', tail)
        e = m.end() + (n.start() if n else len(tail))
        positions.append((s, e))
    print(f"[merge] Found {len(positions)} WanVideoSetLoRAs blocks; collapsing.")
    first_s, first_e = positions[0]
    first = src[first_s:first_e]
    bi = class_body_indent(first)
    # Try to salvage a 'setlora' implementation from later copies
    method = None
    for s, e in positions[1:]:
        blk = src[s:e]
        m = re.search(r'(?m)^\s*def\s+setlora\s*\(', blk)
        if m:
            method = blk[m.start():].rstrip() + "\n"
            def_i = re.match(r'^(\s*)', method).group(1)
            method = "\n".join([L[len(def_i):] if L.startswith(def_i) else L for L in method.splitlines()]) + "\n"
            break
    # Drop the later duplicates
    for s, e in reversed(positions[1:]):
        src = src[:s] + src[e:]
    # Re-insert setlora if missing
    s, e = find_class_block(src, "WanVideoSetLoRAs")
    if s is not None and method and re.search(r'(?m)^\s*def\s+setlora\s*\(', src[s:e]) is None:
        insert_at = e
        src = src[:insert_at] + ("\n" + bi + method.replace("\n", "\n" + bi)).rstrip() + "\n" + src[insert_at:]
    changed += 1

# Ensure INPUT_TYPES is a @classmethod
block = find_class_block(src, "WanVideoSetLoRAs")
if block:
    s, e = block
    klass = src[s:e]
    if re.search(r'(?m)^\s*@classmethod\s*\n\s*def\s+INPUT_TYPES\s*\(', klass) is None:
        if re.search(r'(?m)^\s*def\s+INPUT_TYPES\s*\(', klass):
            new = re.sub(r'(?m)^(\s*)def\s+INPUT_TYPES\s*\(', r'\1@classmethod\n\1def INPUT_TYPES(', klass, count=1)
            src = src[:s] + new + src[e:]
        else:
            bi = class_body_indent(klass)
            ins = (
                f"{bi}@classmethod\n"
                f"{bi}def INPUT_TYPES(s):\n"
                f"{bi}    return {{\n"
                f"{bi}        'required': {{ 'model': ('WANVIDEOMODEL', ), }},\n"
                f"{bi}        'optional': {{ 'lora': ('WANVIDLORA', ), }},\n"
                f"{bi}    }}\n\n"
            )
            line_end = s + src[s:e].find("\n") + 1
            src = src[:line_end] + ins + src[line_end:]
        changed += 1
        print("[fix] ensured INPUT_TYPES in WanVideoSetLoRAs")

# Inject the graft just before the final return of loadmodel(...)
def inject_graft(code, class_name):
    blk = find_class_block(code, class_name)
    if not blk:
        print(f"[loader] class {class_name} not found")
        return code, 0
    s, e = blk
    klass = code[s:e]
    mdef = re.search(r'(?m)^\s*def\s+loadmodel\s*\(', klass)
    if not mdef:
        print(f"[loader] def loadmodel(...) not found in {class_name}")
        return code, 0
    dstart = s + mdef.start()
    m_indent = re.search(r'(?m)^(?P<i>\s*)def\s+loadmodel\s*\(', code[dstart:]).group('i')
    body_i = m_indent + "    "
    method_text = code[dstart:e]
    if "setup_multitalk_for_model(" in method_text:
        print(f"[loader] graft already present in {class_name}")
        return code, 0
    last = None
    for m in re.finditer(r'(?m)^\s*return\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*,\s*\)\s*$', method_text):
        last = m
    var = "patcher"
    if last:
        var = last.group(1)
        graft = (
            f"{body_i}# ---- MultiTalk graft (loader-time) ----\n"
            f"{body_i}if 'multitalk_model' in locals() and multitalk_model is not None:\n"
            f"{body_i}    _need = False\n"
            f"{body_i}    try:\n"
            f"{body_i}        _need = (not hasattr({var}.model if hasattr({var}, 'model') else {var}, 'audio_proj') or\n"
            f"{body_i}                 not hasattr(({var}.model if hasattr({var}, 'model') else {var}).diffusion_model.blocks[0], 'audio_cross_attn'))\n"
            f"{body_i}    except Exception:\n"
            f"{body_i}        _need = True\n"
            f"{body_i}    if _need:\n"
            f"{body_i}        try:\n"
            f"{body_i}            {var} = setup_multitalk_for_model({var}, multitalk_model)\n"
            f"{body_i}            try:\n"
            f"{body_i}                log.info('[MultiTalk] Grafted audio layers in {class_name}.loadmodel')\n"
            f"{body_i}            except Exception:\n"
            f"{body_i}                print('[MultiTalk] Grafted audio layers in {class_name}.loadmodel')\n"
            f"{body_i}        except Exception as _e:\n"
            f"{body_i}            try:\n"
            f"{body_i}                log.warning(f'[MultiTalk] Graft failed in {class_name}: {{_e}}')\n"
            f"{body_i}            except Exception:\n"
            f"{body_i}                print(f'[MultiTalk] Graft failed in {class_name}: {{_e}}')\n"
            f"{body_i}return ({var},)\n"
        )
        rel = dstart
        abs_s = rel + last.start()
        abs_e = rel + last.end()
        code = code[:abs_s] + graft + code[abs_e:]
        return code, 1
    return code, 0

src, ch = inject_graft(src, "WanVideoModelLoader"); changed += ch

if changed:
    io.open(PYFILE, "w", encoding="utf-8").write(src)
    print("[write] updated", PYFILE, "changes:", changed)

# Syntax check for safety
try:
    compile(src, PYFILE, "exec")
    print("[compile] OK:", PYFILE)
except SyntaxError as e:
    print("[compile] FAILED at line", e.lineno, ":", e.msg)
    sys.exit(2)
PY
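
# Optional check (sketch): the patched loader should now reference the shim;
# this just surfaces the injected import/graft lines and is non-fatal.
grep -n "setup_multitalk_for_model" "$LOADER_PY" | head -n 3 || true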

# --- [3/4] Patch the Multi-GPU loader (if present) ---------------------------
# Adds a small helper to import the shim even with hyphenated package names,
# and injects the same loader-time graft as single-GPU.
echo "[3/4] Patch MultiGPU loader(s) (if present): $MGPU_WANVIDEO_PY"
if [ -f "$MGPU_WANVIDEO_PY" ]; then
MGPU_WANVIDEO_PY="$MGPU_WANVIDEO_PY" python3 - <<'PY'
import io, re, sys, os, textwrap

PYFILE = os.environ.get("MGPU_WANVIDEO_PY", "/workspace/ComfyUI/custom_nodes/comfyui-multigpu/wanvideo.py")
src = io.open(PYFILE, "r", encoding="utf-8", errors="ignore").read()
io.open(PYFILE + ".mtbak_unified_v3_before", "w", encoding="utf-8").write(src)
print("[backup]", PYFILE + ".mtbak_unified_v3_before")

changed = 0

# Helper to import shim directly from file path (works around package name)
if "def _mt_load_setup_fn(" not in src and "setup_multitalk_for_model(" not in src:
    helper = textwrap.dedent(r'''
    # --- injected by MultiTalk patcher v3 (lazy-setup import + fallback) ---
    import importlib.util as _importlib_util, os as _os, sys as _sys

    def _mt_load_setup_fn():
        root = _os.path.abspath(_os.path.join(_os.path.dirname(__file__), "..", "ComfyUI-WanVideoWrapper", "wanvideo", "modules", "audio_proj_shim.py"))
        if _os.path.exists(root):
            spec = _importlib_util.spec_from_file_location("_mt_audio_proj_shim", root)
            if spec and spec.loader:
                mod = _importlib_util.module_from_spec(spec)
                try:
                    spec.loader.exec_module(mod)
                    return getattr(mod, "setup_multitalk_for_model", None)
                except Exception:
                    pass
        return None

    _setup_multitalk_for_model = _mt_load_setup_fn()

    def setup_multitalk_for_model(patcher, multitalk_model=None):
        if _setup_multitalk_for_model is not None:
            return _setup_multitalk_for_model(patcher, multitalk_model)
        # Minimal fallback: ensure audio_proj exists so wrapper path can run
        try:
            model = getattr(patcher, "model", None) or patcher
            if not hasattr(model, "audio_proj"):
                import torch.nn as nn
                model.audio_proj = nn.Identity()
            # Tag the model to avoid repeated work
            setattr(model, "_mt_grafted", True)
            try:
                log.info("[MultiTalk] Grafted audio_proj (MGPU fallback).")
            except Exception:
                print("[MultiTalk] Grafted audio_proj (MGPU fallback).")
        except Exception as _e:
            try:
                log.warning(f"[MultiTalk] MGPU fallback graft failed: {_e}")
            except Exception:
                print(f"[MultiTalk] MGPU fallback graft failed: {_e}")
        return patcher
    # --- end injected helper ---
    ''').strip("\n") + "\n\n"
    m = re.search(r'(?ms)^(?:from|import)\s+[^\n]+\n(?:\s*(?:from|import)\s+[^\n]+\n)*', src)
    if m:
        src = src[:m.end()] + helper + src[m.end():]
    else:
        src = helper + src
    changed += 1
    print("[mgpu] injected helper setup loader")

# Reuse class scanner and injection for MGPU loaders
def find_class_block(code, name):
    m = re.search(r'(?m)^class\s+' + re.escape(name) + r'\s*:\s*$', code)
    if not m:
        return None
    s = m.start(); tail = code[m.end():]
    n = re.search(r'(?m)^(class\s+\w+|#region\b|NODE_CLASS_MAPPINGS\s*=)', tail)
    e = m.end() + (n.start() if n else len(tail))
    return (s, e)

def inject_into_loader(code, class_name):
    blk = find_class_block(code, class_name)
    if not blk:
        return code, 0
    s, e = blk
    klass = code[s:e]
    mdef = re.search(r'(?m)^\s*def\s+loadmodel\s*\(', klass)
    if not mdef:
        return code, 0
    dstart = s + mdef.start()
    m_indent = re.search(r'(?m)^(?P<i>\s*)def\s+loadmodel\s*\(', code[dstart:]).group('i')
    body_i = m_indent + "    "
    method_text = code[dstart:e]
    if "setup_multitalk_for_model(" in method_text:
        return code, 0
    last = None
    for m in re.finditer(r'(?m)^\s*return\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*,\s*\)\s*$', method_text):
        last = m
    var = "patcher"
    if last:
        var = last.group(1)
        graft = (
            f"{body_i}# ---- MultiTalk graft (MGPU loader-time) ----\n"
            f"{body_i}if 'multitalk_model' in locals() and multitalk_model is not None:\n"
            f"{body_i}    _need = False\n"
            f"{body_i}    try:\n"
            f"{body_i}        _core = {var}.model if hasattr({var}, 'model') else {var}\n"
            f"{body_i}        _need = (not hasattr(_core, 'audio_proj') or not hasattr(_core.diffusion_model.blocks[0], 'audio_cross_attn'))\n"
            f"{body_i}    except Exception:\n"
            f"{body_i}        _need = True\n"
            f"{body_i}    if _need:\n"
            f"{body_i}        try:\n"
            f"{body_i}            {var} = setup_multitalk_for_model({var}, multitalk_model)\n"
            f"{body_i}            try:\n"
            f"{body_i}                log.info('[MultiTalk] Grafted audio layers in {class_name}.loadmodel (MGPU)')\n"
            f"{body_i}            except Exception:\n"
            f"{body_i}                print('[MultiTalk] Grafted audio layers in {class_name}.loadmodel (MGPU)')\n"
            f"{body_i}        except Exception as _e:\n"
            f"{body_i}            try:\n"
            f"{body_i}                log.warning(f'[MultiTalk] Graft failed in {class_name} (MGPU): {{_e}}')\n"
            f"{body_i}            except Exception:\n"
            f"{body_i}                print(f'[MultiTalk] Graft failed in {class_name} (MGPU): {{_e}}')\n"
            f"{body_i}return ({var},)\n"
        )
        rel = dstart
        abs_s = rel + last.start()
        abs_e = rel + last.end()
        code = code[:abs_s] + graft + code[abs_e:]
        return code, 1
    return code, 0

total = 0
for cname in ("WanVideoModelLoader", "WanVideoModelLoader_2"):
    src, ch = inject_into_loader(src, cname); total += ch
    if ch:
        print(f"[mgpu] graft inserted in {cname}")

if total:
    io.open(PYFILE, "w", encoding="utf-8").write(src)
    print("[write] updated", PYFILE, "changes:", total)

try:
    compile(src, PYFILE, "exec")
    print("[compile] OK:", PYFILE)
except SyntaxError as e:
    print("[compile] FAILED at line", e.lineno, ":", e.msg)
    sys.exit(2)
PY
else
  echo "[3/4] Skipped: MultiGPU file not found ($MGPU_WANVIDEO_PY)."
fi
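
# Optional check (sketch): only meaningful when the MultiGPU loader exists;
# surfaces the injected helper/graft lines and is non-fatal.
[ -f "$MGPU_WANVIDEO_PY" ] && { grep -n "setup_multitalk_for_model" "$MGPU_WANVIDEO_PY" | head -n 3; } || true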

# --- [4/4] Harden MultiTalk wav2vec ingestion path ---------------------------
# Makes the node robust across wav2vec variants/signatures and optional torchaudio.
echo "[4/4] Harden MultiTalk wav2vec (seq_len-smart, torchaudio optional): $NODES_PY"
NODES_PY="$NODES_PY" python3 - <<'PY'
import io, re, os, sys

FILE = os.environ.get("NODES_PY", "/workspace/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper/multitalk/nodes.py")
src = io.open(FILE, "r", encoding="utf-8", errors="ignore").read()
io.open(FILE + ".mtbak_unified_v3_before", "w", encoding="utf-8").write(src)
print("[backup]", FILE + ".mtbak_unified_v3_before")

changes = 0

# Guarded import of a patched wav2vec model if present; remain tolerant if not
if "PatchedWav2Vec2Model" not in src:
    m = re.search(r'(?m)^from\s+\.\.utils\s+import\s+log\s*$', src)
    ins = (
        "try:\n"
        "    from .wav2vec2 import Wav2Vec2Model as PatchedWav2Vec2Model\n"
        "except Exception:\n"
        "    PatchedWav2Vec2Model = None\n"
    )
    src = src[:m.end()] + "\n" + ins + src[m.end():] if m else ins + src
    changes += 1
    print("[import] guarded PatchedWav2Vec2Model")

# Disable any restrictive 'tencent-only' model type guards, if present
new, n = re.subn(
    r'(?m)^\s*if\s+not\s*[\'"]tencent[\'"]\s+in\s+model_type\.lower\(\):\s*\n\s*raise\s+ValueError\([^\n]+\)\s*',
    "    # (patched) disable tencent-only guard\n",
    src); src, changes = new, changes + n

# Make torchaudio optional; require 16 kHz input if it is missing
new, n = re.subn(r'(?m)^(\s*)import\s+torchaudio\s*$',
                 r'\1try:\n\1    import torchaudio\n\1except Exception:\n\1    torchaudio = None',
                 src); src, changes = new, changes + n

# Feature extractor fallback: some packs expose 'processor' instead
new, n = re.subn(
    r'(?m)^(\s*)wav2vec_feature_extractor\s*=\s*wav2vec_model\["feature_extractor"\]\s*$',
    r'\1wav2vec_feature_extractor = wav2vec_model.get("feature_extractor") or wav2vec_model.get("processor")',
    src); src, changes = new, changes + n

# Safer resample block: error clearly if torchaudio is unavailable
new, n = re.subn(
    r'(?m)^(\s*)if\s+sample_rate\s*!=\s*16000:\s*\n\1\s*audio_input\s*=\s*torchaudio\.functional\.resample\([^\n]+\)\s*$',
    r'\1if sample_rate != 16000:\n'
    r'\1    if torchaudio is None:\n'
    r'\1        raise ImportError("torchaudio is required for resampling non-16kHz audio. Provide 16kHz input or install torchaudio.")\n'
    r'\1    audio_input = torchaudio.functional.resample(audio_input, sample_rate, 16000)',
    src); src, changes = new, changes + n

# Accept wav2vec forward() with or without a 'seq_len' kwarg
pat = re.compile(r'(?m)^(?P<i>\s*)embeddings\s*=\s*wav2vec\([^\n]*seq_len[^\n]*\)\s*$')
m = pat.search(src)
if m:
    i = m.group("i")
    repl = (
        f"{i}try:\n"
        f"{i}    outputs = wav2vec(audio_feature.to(dtype), seq_len=int(video_length), output_hidden_states=True)\n"
        f"{i}except TypeError:\n"
        f"{i}    outputs = wav2vec(audio_feature.to(dtype), output_hidden_states=True)"
    )
    src = src[:m.start()] + repl + src[m.end():]
    changes += 1

# Unify 'hidden_states' access and resample time dim to video_length as needed
if "hidden_states = getattr(outputs, 'hidden_states'" not in src:
    new = src.replace(
        "wav2vec.to(offload_device)",
        "wav2vec.to(offload_device)\n"
        "\n"
        "        # Unify hidden_states; interpolate to video_length if needed\n"
        "        hidden_states = getattr(outputs, 'hidden_states', None)\n"
        "        if hidden_states is None and isinstance(outputs, (list, tuple)):\n"
        "            hidden_states = outputs\n"
        "        if hidden_states is None:\n"
        "            raise RuntimeError('wav2vec output does not contain hidden_states')\n"
        "        target_len = int(video_length)\n"
        "        if target_len > 0 and len(hidden_states) > 0 and hidden_states[0].shape[1] != target_len:\n"
        "            resampled = []\n"
        "            for h in hidden_states:\n"
        "                ht = h.transpose(1, 2)\n"
        "                ht = torch.nn.functional.interpolate(ht, size=target_len, mode='linear', align_corners=False)\n"
        "                resampled.append(ht.transpose(1, 2))\n"
        "            hidden_states = tuple(resampled)")
    if new != src:
        src = new; changes += 1

# Replace downstream references to 'embeddings.hidden_states' with 'hidden_states'
src2 = re.sub(r'(?m)embeddings\.hidden_states', 'hidden_states', src)
if src2 != src:
    src = src2; changes += 1
src2, n = re.subn(r'(?m)if\s+len\(\s*embeddings\s*\)\s*==\s*0\s*:', 'if len(hidden_states) == 0:', src)
src, changes = src2, changes + n

if changes:
    io.open(FILE, "w", encoding="utf-8").write(src)
    print(f"[apply] {changes} edits written")
else:
    print("[apply] no edits needed")

# Syntax check
try:
    compile(src, FILE, "exec")
    print("[compile] OK:", FILE)
except SyntaxError as e:
    print("[compile] FAILED at line", e.lineno, ":", e.msg)
    sys.exit(2)
PY
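
# Optional check (sketch): surface the wav2vec hardening markers written above
# (they appear only when the corresponding patterns were found); non-fatal.
grep -n -e "PatchedWav2Vec2Model" -e "torchaudio = None" "$NODES_PY" | head -n 5 || true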

# --- Operator guidance / quick checks ----------------------------------------
echo "----------------------------------------------------------------------------"
echo "Restart ComfyUI. On model load you should see one of:"
echo "  [MultiTalk] Grafted audio layers in WanVideoModelLoader.loadmodel"
echo "  [MultiTalk] Grafted audio layers in WanVideoModelLoader.loadmodel (MGPU)"
echo "  [MultiTalk] Grafted audio_proj + audio_cross_attn at load time (shim)."
echo "If you still get 'audio_proj' missing, paste:"
echo "  nl -ba '$LOADER_PY' | sed -n '700,940p'"
echo "  [ -f '$MGPU_WANVIDEO_PY' ] && nl -ba '$MGPU_WANVIDEO_PY' | sed -n '1,220p'"