# === ComfyUI WanVideoWrapper — Unified MultiTalk Patcher v3 ==================
# Purpose:
#   Enable MultiTalk on Wan 2.2 I2V 14B GGUF exports by grafting an audio
#   projection + audio cross-attention path at model load time, and make
#   wav2vec ingestion robust. Safetensors builds are unaffected.
#
# Key fixes:
#   • Provide/attach audio_proj + audio_cross_attn if missing (GGUF gap).
#   • Apply grafts in BOTH single-GPU and Multi-GPU loaders.
#   • Harden wav2vec path (seq_len signature diffs, optional torchaudio, unify
#     hidden_states and resample to video_length).
#
# Safety:
#   • Idempotent; backups written once as *.mtbak_unified_v3_before.
#   • Clear logs on graft execution to simplify troubleshooting.
# ============================================================================

set -euo pipefail  # Fail fast, treat unset vars as errors

# --- Paths (override via env if your layout differs) -------------------------
WRAP_DIR="${WRAP_DIR:-/workspace/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper}"
MGPU_DIR="${MGPU_DIR:-/workspace/ComfyUI/custom_nodes/comfyui-multigpu}"

MOD_DIR="$WRAP_DIR/wanvideo/modules"           # Where the shim will live
MULTITALK_DIR="$WRAP_DIR/multitalk"            # MultiTalk nodes location

LOADER_PY="$WRAP_DIR/nodes_model_loading.py"   # Single-GPU loader
MGPU_WANVIDEO_PY="$MGPU_DIR/wanvideo.py"       # Multi-GPU loader (if present)
NODES_PY="$MULTITALK_DIR/nodes.py"             # MultiTalk node definitions
AP_SHIM="$MOD_DIR/audio_proj_shim.py"          # Shim target path

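# Example invocation (a sketch; the script filename below is illustrative, not
# part of this patcher). The env overrides are only needed if your ComfyUI tree
# differs from the defaults above:
#   WRAP_DIR=/opt/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper \
#   MGPU_DIR=/opt/ComfyUI/custom_nodes/comfyui-multigpu \
#   bash unified_multitalk_patcher_v3.sh
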
# --- Sanity checks -----------------------------------------------------------
command -v python3 >/dev/null || { echo "python3 is required"; exit 1; }
test -d "$WRAP_DIR" || { echo "WanVideoWrapper not found: $WRAP_DIR"; exit 1; }
test -f "$LOADER_PY" || { echo "Missing: $LOADER_PY"; exit 1; }
test -f "$NODES_PY" || { echo "Missing: $NODES_PY"; exit 1; }
mkdir -p "$MOD_DIR" "$MULTITALK_DIR"

# --- [1/4] Write/update the runtime shim that performs the graft -------------
# The shim:
#   • Unwraps the underlying Wan model from the patcher.
#   • Ensures 'audio_proj' exists (uses pack proj if provided; Identity otherwise).
#   • Adds 'audio_cross_attn' and 'norm_x' per block if missing.
#   • Loads any provided state dict softly and aligns device/dtype.
#   • Marks _mt_grafted to avoid repeated work across reloads.
echo "[1/4] Ensure shim: $AP_SHIM"
AP_SHIM="$AP_SHIM" python3 - <<'PY'  # pass the target path so the env override is honored
import os, io
dst = os.environ.get("AP_SHIM", "/workspace/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper/wanvideo/modules/audio_proj_shim.py")
src = r'''# Auto-generated by unified MultiTalk patcher v3
# Role: Provide a late-binding graft for MultiTalk audio ingestion when
#       Wan GGUF exports lack audio_proj + audio_cross_attn.

import torch
from torch import nn
from ...multitalk.multitalk import SingleStreamMultiAttention
from .model import WanLayerNorm, WanRMSNorm

__all__ = ["setup_multitalk_for_model"]

def _unwrap_model(patcher_or_model):
    """Locate the underlying Wan model regardless of wrapper/patcher type."""
    m = getattr(patcher_or_model, "model", None)
    if m is None and hasattr(patcher_or_model, "diffusion_model"):
        m = patcher_or_model
    if m is None:
        for k in ("_model", "_unet"):
            m = getattr(patcher_or_model, k, None)
            if m is not None:
                break
    if m is None:
        raise RuntimeError("Could not unwrap model from patcher")
    return m

def setup_multitalk_for_model(patcher, multitalk_model=None):
    """Attach audio projection + cross-attn only if missing (idempotent)."""
    model = _unwrap_model(patcher)

    # Skip if we already grafted this instance
    if getattr(model, "_mt_grafted", False):
        return patcher

    # If blocks already have audio cross-attn, assume prior integration
    try:
        if hasattr(model.diffusion_model.blocks[0], "audio_cross_attn"):
            model._mt_grafted = True
            return patcher
    except Exception:
        pass

    # Need a MultiTalk pack to source dimensions/weights when available
    if multitalk_model is None:
        return patcher

    # Prefer a projection supplied by the pack (proj_model); fall back to Identity
    proj = multitalk_model.get("proj_model", None) if isinstance(multitalk_model, dict) else None
    if not hasattr(model, "audio_proj"):
        if proj is None:
            # Keep it conservative; dimension discovery happens downstream
            proj = nn.Identity()
        model.audio_proj = proj

    # Attach per-block cross-attn + norm if absent
    try:
        blocks = model.diffusion_model.blocks
    except Exception as e:
        raise RuntimeError(f"[MultiTalk] Unexpected WanModel structure: {e}")

    for blk in blocks:
        dim = getattr(blk, "dim", None) or getattr(blk, "hidden_size", None)
        if dim is None:
            raise RuntimeError("[MultiTalk] block.dim missing")
        num_heads = getattr(blk, "num_heads", 8)
        eps = getattr(blk, "eps", 1e-6)
        attention_mode = getattr(blk, "attention_mode", "sdpa")

        if not hasattr(blk, "audio_cross_attn"):
            blk.audio_cross_attn = SingleStreamMultiAttention(
                dim=dim, encoder_hidden_states_dim=768, num_heads=num_heads,
                qk_norm=False, qkv_bias=True, eps=eps, norm_layer=WanRMSNorm,
                attention_mode=attention_mode,
            )
        if not hasattr(blk, "norm_x"):
            blk.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True)

    # Best-effort soft load of any provided state dict
    try:
        sd = multitalk_model.get("sd", {}) if isinstance(multitalk_model, dict) else {}
        model.load_state_dict(sd, strict=False)
    except Exception:
        pass

    # Match device/dtype of the main model params
    try:
        ref = next(model.diffusion_model.parameters())
        model.to(ref.device, dtype=ref.dtype)
    except Exception:
        pass

    model._mt_grafted = True
    print("[MultiTalk] Grafted audio_proj + audio_cross_attn at load time (shim).")
    return patcher
'''
os.makedirs(os.path.dirname(dst), exist_ok=True)
old = io.open(dst, "r", encoding="utf-8", errors="ignore").read() if os.path.exists(dst) else None
if old != src:
    io.open(dst, "w", encoding="utf-8").write(src)
    print(f"[shim] wrote {dst}")
else:
    print(f"[shim] up-to-date {dst}")
PY
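
# Optional: byte-compile the generated shim as a quick syntax check (a sketch;
# py_compile only parses the file, it does not resolve the shim's relative
# imports, which need ComfyUI's package context at runtime).
AP_SHIM="$AP_SHIM" python3 -c "import os, py_compile; py_compile.compile(os.environ['AP_SHIM'], doraise=True)" \
  && echo "[shim] syntax OK: $AP_SHIM" \
  || echo "[shim] WARNING: byte-compile failed: $AP_SHIM"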

# --- [2/4] Patch the single-GPU loader to call the shim at the right time ----
# Also includes minor hygiene (duplicate class collapse / INPUT_TYPES) and
# injects the graft just before the final return in loadmodel(...).
echo "[2/4] Patch WanVideoWrapper loader + LoRA class (nodes_model_loading.py)"
LOADER_PY="$LOADER_PY" python3 - <<'PY'  # pass the loader path so env overrides are honored
import io, re, sys, os

PYFILE = os.environ.get("LOADER_PY", "/workspace/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper/nodes_model_loading.py")
src = io.open(PYFILE, "r", encoding="utf-8", errors="ignore").read()
io.open(PYFILE + ".mtbak_unified_v3_before", "w", encoding="utf-8").write(src)
print("[backup]", PYFILE + ".mtbak_unified_v3_before")

changed = 0

# Ensure we can import our shim locally
if "setup_multitalk_for_model" not in src:
    ins = "from .wanvideo.modules.audio_proj_shim import setup_multitalk_for_model\n"
    m = re.search(r'(?m)^from\s+\.utils\s+import\s+log\s*$', src)
    if m:
        src = src[:m.end()] + "\n" + ins + src[m.end():]
    else:
        top = re.search(r'(?m)^(?:from|import)\s+[^\n]+\n(?:\s*(?:from|import)\s+[^\n]+\n)*', src)
        src = src[:top.end()] + ins + src[top.end():] if top else ins + src
    changed += 1
    print("[import] added loader shim import")

# Utilities to find / edit class blocks
def find_class_block(code, name):
    m = re.search(r'(?m)^class\s+' + re.escape(name) + r'\s*:\s*$', code)
    if not m:
        return None
    s = m.start(); tail = code[m.end():]
    n = re.search(r'(?m)^(class\s+\w+|#region\b|NODE_CLASS_MAPPINGS\s*=)', tail)
    e = m.end() + (n.start() if n else len(tail))
    return (s, e)

def class_body_indent(txt):
    for L in txt.splitlines()[1:]:
        if L.strip() and not L.lstrip().startswith("#"):
            return re.match(r'^(\s*)', L).group(1)
    return "    "

# Collapse duplicate WanVideoSetLoRAs definitions (defensive)
blocks = [(m.start(), None) for m in re.finditer(r'(?m)^class\s+WanVideoSetLoRAs\s*:\s*$', src)]
if len(blocks) > 1:
    positions = []
    for m in re.finditer(r'(?m)^class\s+WanVideoSetLoRAs\s*:\s*$', src):
        s = m.start(); tail = src[m.end():]
        n = re.search(r'(?m)^(class\s+\w+|#region\b|NODE_CLASS_MAPPINGS\s*=)', tail)
        e = m.end() + (n.start() if n else len(tail))
        positions.append((s, e))
    print(f"[merge] Found {len(positions)} WanVideoSetLoRAs blocks; collapsing.")
    first_s, first_e = positions[0]
    first = src[first_s:first_e]
    bi = class_body_indent(first)
    # Try to salvage a 'setlora' implementation from later copies
    method = None
    for s, e in positions[1:]:
        blk = src[s:e]
        m = re.search(r'(?m)^\s*def\s+setlora\s*\(', blk)
        if m:
            method = blk[m.start():].rstrip() + "\n"
            def_i = re.match(r'^(\s*)', method).group(1)
            method = "\n".join([L[len(def_i):] if L.startswith(def_i) else L for L in method.splitlines()]) + "\n"
            break
    # Drop the later duplicates
    for s, e in reversed(positions[1:]):
        src = src[:s] + src[e:]
    # Re-insert setlora if missing
    s, e = find_class_block(src, "WanVideoSetLoRAs")
    if s is not None and method and re.search(r'(?m)^\s*def\s+setlora\s*\(', src[s:e]) is None:
        insert_at = e
        src = src[:insert_at] + ("\n" + bi + method.replace("\n", "\n" + bi)).rstrip() + "\n" + src[insert_at:]
    changed += 1

# Ensure INPUT_TYPES is a @classmethod
block = find_class_block(src, "WanVideoSetLoRAs")
if block:
    s, e = block
    klass = src[s:e]
    if re.search(r'(?m)^\s*@classmethod\s*\n\s*def\s+INPUT_TYPES\s*\(', klass) is None:
        if re.search(r'(?m)^\s*def\s+INPUT_TYPES\s*\(', klass):
            new = re.sub(r'(?m)^(\s*)def\s+INPUT_TYPES\s*\(', r'\1@classmethod\n\1def INPUT_TYPES(', klass, count=1)
            src = src[:s] + new + src[e:]
        else:
            bi = class_body_indent(klass)
            ins = (
                f"{bi}@classmethod\n"
                f"{bi}def INPUT_TYPES(s):\n"
                f"{bi}    return {{\n"
                f"{bi}        'required': {{ 'model': ('WANVIDEOMODEL', ), }},\n"
                f"{bi}        'optional': {{ 'lora': ('WANVIDLORA', ), }},\n"
                f"{bi}    }}\n\n"
            )
            line_end = s + src[s:e].find("\n") + 1
            src = src[:line_end] + ins + src[line_end:]
        changed += 1
        print("[fix] ensured INPUT_TYPES in WanVideoSetLoRAs")

# Inject the graft just before the final return of loadmodel(...)
def inject_graft(code, class_name):
    blk = find_class_block(code, class_name)
    if not blk:
        print(f"[loader] class {class_name} not found")
        return code, 0
    s, e = blk
    klass = code[s:e]
    mdef = re.search(r'(?m)^\s*def\s+loadmodel\s*\(', klass)
    if not mdef:
        print(f"[loader] def loadmodel(...) not found in {class_name}")
        return code, 0
    dstart = s + mdef.start()
    m_indent = re.search(r'(?m)^(?P<i>\s*)def\s+loadmodel\s*\(', code[dstart:]).group('i')
    body_i = m_indent + "    "
    method_text = code[dstart:e]
    if "setup_multitalk_for_model(" in method_text:
        print(f"[loader] graft already present in {class_name}")
        return code, 0

    last = None
    for m in re.finditer(r'(?m)^\s*return\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*,\s*\)\s*$', method_text):
        last = m
    var = "patcher"
    if last:
        var = last.group(1)
        graft = (
            f"{body_i}# ---- MultiTalk graft (loader-time) ----\n"
            f"{body_i}if 'multitalk_model' in locals() and multitalk_model is not None:\n"
            f"{body_i}    _need = False\n"
            f"{body_i}    try:\n"
            f"{body_i}        _need = (not hasattr({var}.model if hasattr({var}, 'model') else {var}, 'audio_proj') or\n"
            f"{body_i}                 not hasattr(({var}.model if hasattr({var}, 'model') else {var}).diffusion_model.blocks[0], 'audio_cross_attn'))\n"
            f"{body_i}    except Exception:\n"
            f"{body_i}        _need = True\n"
            f"{body_i}    if _need:\n"
            f"{body_i}        try:\n"
            f"{body_i}            {var} = setup_multitalk_for_model({var}, multitalk_model)\n"
            f"{body_i}            try:\n"
            f"{body_i}                log.info('[MultiTalk] Grafted audio layers in {class_name}.loadmodel')\n"
            f"{body_i}            except Exception:\n"
            f"{body_i}                print('[MultiTalk] Grafted audio layers in {class_name}.loadmodel')\n"
            f"{body_i}        except Exception as _e:\n"
            f"{body_i}            try:\n"
            f"{body_i}                log.warning(f'[MultiTalk] Graft failed in {class_name}: {{_e}}')\n"
            f"{body_i}            except Exception:\n"
            f"{body_i}                print(f'[MultiTalk] Graft failed in {class_name}: {{_e}}')\n"
            f"{body_i}return ({var},)\n"
        )
        abs_s = dstart + last.start()
        abs_e = dstart + last.end()
        code = code[:abs_s] + graft + code[abs_e:]
        return code, 1
    return code, 0

src, ch = inject_graft(src, "WanVideoModelLoader"); changed += ch

if changed:
    io.open(PYFILE, "w", encoding="utf-8").write(src)
    print("[write] updated", PYFILE, "changes:", changed)

# Syntax check for safety
try:
    compile(src, PYFILE, "exec")
    print("[compile] OK:", PYFILE)
except SyntaxError as e:
    print("[compile] FAILED at line", e.lineno, ":", e.msg)
    sys.exit(2)
PY
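
# Quick verification (a sketch): the loader should now import the shim and call
# setup_multitalk_for_model(...) just before returning the patcher.
grep -n "setup_multitalk_for_model" "$LOADER_PY" | head -n 5 || true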

# --- [3/4] Patch the Multi-GPU loader (if present) ---------------------------
# Adds a small helper to import the shim even with hyphenated package names,
# and injects the same loader-time graft as single-GPU.
echo "[3/4] Patch MultiGPU loader(s) (if present): $MGPU_WANVIDEO_PY"
if [ -f "$MGPU_WANVIDEO_PY" ]; then
MGPU_WANVIDEO_PY="$MGPU_WANVIDEO_PY" python3 - <<'PY'  # pass the MGPU loader path so env overrides are honored
import io, re, sys, os, textwrap

PYFILE = os.environ.get("MGPU_WANVIDEO_PY", "/workspace/ComfyUI/custom_nodes/comfyui-multigpu/wanvideo.py")
src = io.open(PYFILE, "r", encoding="utf-8", errors="ignore").read()
io.open(PYFILE + ".mtbak_unified_v3_before", "w", encoding="utf-8").write(src)
print("[backup]", PYFILE + ".mtbak_unified_v3_before")

changed = 0

# Helper to import shim directly from file path (works around package name)
if "def _mt_load_setup_fn(" not in src and "setup_multitalk_for_model(" not in src:
    helper = textwrap.dedent(r'''
    # --- injected by MultiTalk patcher v3 (lazy-setup import + fallback) ---
    import importlib.util as _importlib_util, os as _os, sys as _sys
    def _mt_load_setup_fn():
        root = _os.path.abspath(_os.path.join(_os.path.dirname(__file__), "..", "ComfyUI-WanVideoWrapper", "wanvideo", "modules", "audio_proj_shim.py"))
        if _os.path.exists(root):
            spec = _importlib_util.spec_from_file_location("_mt_audio_proj_shim", root)
            if spec and spec.loader:
                mod = _importlib_util.module_from_spec(spec)
                try:
                    spec.loader.exec_module(mod)
                    return getattr(mod, "setup_multitalk_for_model", None)
                except Exception:
                    pass
        return None
    _setup_multitalk_for_model = _mt_load_setup_fn()

    def setup_multitalk_for_model(patcher, multitalk_model=None):
        if _setup_multitalk_for_model is not None:
            return _setup_multitalk_for_model(patcher, multitalk_model)
        # Minimal fallback: ensure audio_proj exists so wrapper path can run
        try:
            model = getattr(patcher, "model", None) or patcher
            if not hasattr(model, "audio_proj"):
                import torch.nn as nn
                model.audio_proj = nn.Identity()
            # Tag the model to avoid repeated work
            setattr(model, "_mt_grafted", True)
            try:
                log.info("[MultiTalk] Grafted audio_proj (MGPU fallback).")
            except Exception:
                print("[MultiTalk] Grafted audio_proj (MGPU fallback).")
        except Exception as _e:
            try:
                log.warning(f"[MultiTalk] MGPU fallback graft failed: {_e}")
            except Exception:
                print(f"[MultiTalk] MGPU fallback graft failed: {_e}")
        return patcher
    # --- end injected helper ---
    ''').strip("\n") + "\n\n"
    m = re.search(r'(?ms)^(?:from|import)\s+[^\n]+\n(?:\s*(?:from|import)\s+[^\n]+\n)*', src)
    if m:
        src = src[:m.end()] + helper + src[m.end():]
    else:
        src = helper + src
    changed += 1
    print("[mgpu] injected helper setup loader")

# Reuse class scanner and injection for MGPU loaders
def find_class_block(code, name):
    m = re.search(r'(?m)^class\s+' + re.escape(name) + r'\s*:\s*$', code)
    if not m:
        return None
    s = m.start(); tail = code[m.end():]
    n = re.search(r'(?m)^(class\s+\w+|#region\b|NODE_CLASS_MAPPINGS\s*=)', tail)
    e = m.end() + (n.start() if n else len(tail))
    return (s, e)

def inject_into_loader(code, class_name):
    blk = find_class_block(code, class_name)
    if not blk:
        return code, 0
    s, e = blk
    klass = code[s:e]
    mdef = re.search(r'(?m)^\s*def\s+loadmodel\s*\(', klass)
    if not mdef:
        return code, 0
    dstart = s + mdef.start()
    m_indent = re.search(r'(?m)^(?P<i>\s*)def\s+loadmodel\s*\(', code[dstart:]).group('i')
    body_i = m_indent + "    "
    method_text = code[dstart:e]
    if "setup_multitalk_for_model(" in method_text:
        return code, 0

    last = None
    for m in re.finditer(r'(?m)^\s*return\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*,\s*\)\s*$', method_text):
        last = m
    var = "patcher"
    if last:
        var = last.group(1)
        graft = (
            f"{body_i}# ---- MultiTalk graft (MGPU loader-time) ----\n"
            f"{body_i}if 'multitalk_model' in locals() and multitalk_model is not None:\n"
            f"{body_i}    _need = False\n"
            f"{body_i}    try:\n"
            f"{body_i}        _core = {var}.model if hasattr({var}, 'model') else {var}\n"
            f"{body_i}        _need = (not hasattr(_core, 'audio_proj') or not hasattr(_core.diffusion_model.blocks[0], 'audio_cross_attn'))\n"
            f"{body_i}    except Exception:\n"
            f"{body_i}        _need = True\n"
            f"{body_i}    if _need:\n"
            f"{body_i}        try:\n"
            f"{body_i}            {var} = setup_multitalk_for_model({var}, multitalk_model)\n"
            f"{body_i}            try:\n"
            f"{body_i}                log.info('[MultiTalk] Grafted audio layers in {class_name}.loadmodel (MGPU)')\n"
            f"{body_i}            except Exception:\n"
            f"{body_i}                print('[MultiTalk] Grafted audio layers in {class_name}.loadmodel (MGPU)')\n"
            f"{body_i}        except Exception as _e:\n"
            f"{body_i}            try:\n"
            f"{body_i}                log.warning(f'[MultiTalk] Graft failed in {class_name} (MGPU): {{_e}}')\n"
            f"{body_i}            except Exception:\n"
            f"{body_i}                print(f'[MultiTalk] Graft failed in {class_name} (MGPU): {{_e}}')\n"
            f"{body_i}return ({var},)\n"
        )
        abs_s = dstart + last.start()
        abs_e = dstart + last.end()
        code = code[:abs_s] + graft + code[abs_e:]
        return code, 1
    return code, 0

total = 0
for cname in ("WanVideoModelLoader", "WanVideoModelLoader_2"):
    src, ch = inject_into_loader(src, cname); total += ch
    if ch: print(f"[mgpu] graft inserted in {cname}")

if total:
    io.open(PYFILE, "w", encoding="utf-8").write(src)
    print("[write] updated", PYFILE, "changes:", total)

try:
    compile(src, PYFILE, "exec")
    print("[compile] OK:", PYFILE)
except SyntaxError as e:
    print("[compile] FAILED at line", e.lineno, ":", e.msg)
    sys.exit(2)
PY
else
  echo "[3/4] Skipped: MultiGPU file not found ($MGPU_WANVIDEO_PY)."
fi
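
# Quick verification (a sketch): confirm the MGPU loader now references the graft.
[ -f "$MGPU_WANVIDEO_PY" ] && grep -n "setup_multitalk_for_model" "$MGPU_WANVIDEO_PY" | head -n 5 || true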

# --- [4/4] Harden MultiTalk wav2vec ingestion path ---------------------------
# Makes the node robust across wav2vec variants/signatures and tolerant of a missing torchaudio.
echo "[4/4] Harden MultiTalk wav2vec (seq_len-smart, torchaudio optional): $NODES_PY"
NODES_PY="$NODES_PY" python3 - <<'PY'  # pass the nodes path so env overrides are honored
import io, os, re

FILE = os.environ.get("NODES_PY", "/workspace/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper/multitalk/nodes.py")
src = io.open(FILE, "r", encoding="utf-8", errors="ignore").read()
io.open(FILE + ".mtbak_unified_v3_before", "w", encoding="utf-8").write(src)
print("[backup]", FILE + ".mtbak_unified_v3_before")
changes = 0

# Guarded import of a patched wav2vec model if present; remain tolerant if not
if "PatchedWav2Vec2Model" not in src:
    m = re.search(r'(?m)^from\s+\.\.utils\s+import\s+log\s*$', src)
    ins = (
        "try:\n"
        "    from .wav2vec2 import Wav2Vec2Model as PatchedWav2Vec2Model\n"
        "except Exception:\n"
        "    PatchedWav2Vec2Model = None\n"
    )
    src = src[:m.end()] + "\n" + ins + src[m.end():] if m else ins + src
    changes += 1
    print("[import] guarded PatchedWav2Vec2Model")

# Disable any restrictive 'tencent-only' model type guards, if present
new, n = re.subn(
    r'(?m)^\s*if\s+not\s*[\'"]tencent[\'"]\s+in\s+model_type\.lower\(\):\s*\n\s*raise\s+ValueError\([^\n]+\)\s*',
    "        # (patched) disable tencent-only guard\n",
    src); src, changes = new, changes + n

# Make torchaudio optional; require 16 kHz input if it is missing
new, n = re.subn(r'(?m)^(\s*)import\s+torchaudio\s*$',
                 r'\1try:\n\1    import torchaudio\n\1except Exception:\n\1    torchaudio = None',
                 src); src, changes = new, changes + n

# Feature extractor fallback: some packs expose 'processor' instead
new, n = re.subn(
    r'(?m)^(\s*)wav2vec_feature_extractor\s*=\s*wav2vec_model\["feature_extractor"\]\s*$',
    r'\1wav2vec_feature_extractor = wav2vec_model.get("feature_extractor") or wav2vec_model.get("processor")',
    src); src, changes = new, changes + n

# Safer resample block: error clearly if torchaudio is unavailable
new, n = re.subn(
    r'(?m)^(\s*)if\s+sample_rate\s*!=\s*16000:\s*\n\1\s*audio_input\s*=\s*torchaudio\.functional\.resample\([^\n]+\)\s*$',
    r'\1if sample_rate != 16000:\n'
    r'\1    if torchaudio is None:\n'
    r'\1        raise ImportError("torchaudio is required for resampling non-16kHz audio. Provide 16kHz input or install torchaudio.")\n'
    r'\1    audio_input = torchaudio.functional.resample(audio_input, sample_rate, sr)',
    src); src, changes = new, changes + n

# Accept wav2vec forward() with or without a 'seq_len' kwarg
import re as _re
pat = _re.compile(r'(?m)^(?P<i>\s*)embeddings\s*=\s*wav2vec\([^\n]*seq_len[^\n]*\)\s*$')
m = pat.search(src)
if m:
    i = m.group("i")
    repl = (
        f"{i}try:\n"
        f"{i}    outputs = wav2vec(audio_feature.to(dtype), seq_len=int(video_length), output_hidden_states=True)\n"
        f"{i}except TypeError:\n"
        f"{i}    outputs = wav2vec(audio_feature.to(dtype), output_hidden_states=True)"
    )
    src = src[:m.start()] + repl + src[m.end():]
    changes += 1

# Unify 'hidden_states' access and resample time dim to video_length as needed
if "hidden_states = getattr(outputs, 'hidden_states'" not in src:
    new = src.replace(
        "wav2vec.to(offload_device)",
        "wav2vec.to(offload_device)\n"
        "\n"
        "        # Unify hidden_states; interpolate to video_length if needed\n"
        "        hidden_states = getattr(outputs, 'hidden_states', None)\n"
        "        if hidden_states is None and isinstance(outputs, (list, tuple)):\n"
        "            hidden_states = outputs\n"
        "        if hidden_states is None:\n"
        "            raise RuntimeError('wav2vec output does not contain hidden_states')\n"
        "        target_len = int(video_length)\n"
        "        if target_len > 0 and len(hidden_states) > 0 and hidden_states[0].shape[1] != target_len:\n"
        "            resampled = []\n"
        "            for h in hidden_states:\n"
        "                ht = h.transpose(1, 2)\n"
        "                ht = torch.nn.functional.interpolate(ht, size=target_len, mode='linear', align_corners=False)\n"
        "                resampled.append(ht.transpose(1, 2))\n"
        "            hidden_states = tuple(resampled)")
    if new != src:
        src = new; changes += 1

# Replace downstream references to 'embeddings.hidden_states' with 'hidden_states'
src2 = re.sub(r'(?m)embeddings\.hidden_states', 'hidden_states', src)
if src2 != src:
    src = src2; changes += 1
src2, n = re.subn(r'(?m)if\s+len\(\s*embeddings\s*\)\s*==\s*0\s*:', 'if len(hidden_states) == 0:', src)
src, changes = src2, changes + n

if changes:
    io.open(FILE, "w", encoding="utf-8").write(src)
    print(f"[apply] {changes} edits written")
else:
    print("[apply] no edits needed")

# Syntax check
import sys
try:
    compile(src, FILE, "exec")
    print("[compile] OK:", FILE)
except SyntaxError as e:
    print("[compile] FAILED at line", e.lineno, ":", e.msg)
    sys.exit(2)
PY
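
# Quick verification (a sketch): the wav2vec path should now tolerate a missing
# seq_len kwarg and a missing torchaudio install.
grep -nE "output_hidden_states=True|torchaudio = None" "$NODES_PY" | head -n 5 || true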

# --- Operator guidance / quick checks ----------------------------------------
echo "----------------------------------------------------------------------------"
echo "Restart ComfyUI. On model load you should see one of:"
echo "  [MultiTalk] Grafted audio layers in WanVideoModelLoader.loadmodel"
echo "  [MultiTalk] Grafted audio layers in WanVideoModelLoader.loadmodel (MGPU)"
echo "  [MultiTalk] Grafted audio_proj + audio_cross_attn at load time (shim)."
echo "If you still get 'audio_proj' missing, paste the output of:"
echo "  nl -ba '$LOADER_PY' | sed -n '700,940p'"
echo "  [ -f '$MGPU_WANVIDEO_PY' ] && nl -ba '$MGPU_WANVIDEO_PY' | sed -n '1,220p'"
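
# Rollback (a sketch): each patched file has a backup next to it with the
# *.mtbak_unified_v3_before suffix; copy it back to undo this patcher.
#   cp "$LOADER_PY.mtbak_unified_v3_before" "$LOADER_PY"
#   [ -f "$MGPU_WANVIDEO_PY.mtbak_unified_v3_before" ] && cp "$MGPU_WANVIDEO_PY.mtbak_unified_v3_before" "$MGPU_WANVIDEO_PY"
#   cp "$NODES_PY.mtbak_unified_v3_before" "$NODES_PY"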