hunyuan_video_packed.py - FramePack GTX 1080 Ti Fix

Apr 28th, 2025 | Python
  1. from typing import Any, Dict, List, Optional, Tuple, Union
  2.  
  3. import torch
  4. import einops
  5. import torch.nn as nn
  6. import numpy as np
  7.  
  8. from diffusers.loaders import FromOriginalModelMixin
  9. from diffusers.configuration_utils import ConfigMixin, register_to_config
  10. from diffusers.loaders import PeftAdapterMixin
  11. from diffusers.utils import logging
  12. from diffusers.models.attention import FeedForward
  13. from diffusers.models.attention_processor import Attention
  14. from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
  15. from diffusers.models.modeling_outputs import Transformer2DModelOutput
  16. from diffusers.models.modeling_utils import ModelMixin
  17. from diffusers_helper.dit_common import LayerNorm
  18. from diffusers_helper.utils import zero_module
  19.  
  20.  
  21. enabled_backends = []
  22.  
  23. if torch.backends.cuda.flash_sdp_enabled():
  24.     enabled_backends.append("flash")
  25. if torch.backends.cuda.math_sdp_enabled():
  26.     enabled_backends.append("math")
  27. if torch.backends.cuda.mem_efficient_sdp_enabled():
  28.     enabled_backends.append("mem_efficient")
  29. if torch.backends.cuda.cudnn_sdp_enabled():
  30.     enabled_backends.append("cudnn")
  31.  
  32. print("Currently enabled native sdp backends:", enabled_backends)
  33.  
  34. try:
  35.     # raise NotImplementedError
  36.     from xformers.ops import memory_efficient_attention as xformers_attn_func
  37.     print('Xformers is installed!')
  38. except Exception:
  39.     print('Xformers is not installed!')
  40.     xformers_attn_func = None
  41.  
  42. try:
  43.     # raise NotImplementedError
  44.     from flash_attn import flash_attn_varlen_func, flash_attn_func
  45.     print('Flash Attn is installed!')
  46. except Exception:
  47.     print('Flash Attn is not installed!')
  48.     flash_attn_varlen_func = None
  49.     flash_attn_func = None
  50.  
  51. try:
  52.     # raise NotImplementedError
  53.     from sageattention import sageattn_varlen, sageattn
  54.     print('Sage Attn is installed!')
  55. except Exception:
  56.     print('Sage Attn is not installed!')
  57.     sageattn_varlen = None
  58.     sageattn = None
  59.  
  60.  
  61. logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
  62.  
  63.  
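# Note on the two helpers below: pad_for_3d_conv pads the trailing edges of the (t, h, w) dims with
# replicated values so each becomes a multiple of the given kernel size (e.g. t=3 with pt=2 gains one
# extra frame), which lets the strided Conv3d patch embedders and avg_pool3d cover every voxel exactly.
# center_down_sample_3d then reduces a padded tensor by plain average pooling with the same kernel.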
  64. def pad_for_3d_conv(x, kernel_size):
  65.     b, c, t, h, w = x.shape
  66.     pt, ph, pw = kernel_size
  67.     pad_t = (pt - (t % pt)) % pt
  68.     pad_h = (ph - (h % ph)) % ph
  69.     pad_w = (pw - (w % pw)) % pw
  70.     return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
  71.  
  72.  
  73. def center_down_sample_3d(x, kernel_size):
  74.     # pt, ph, pw = kernel_size
  75.     # cp = (pt * ph * pw) // 2
  76.     # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
  77.     # xc = xp[cp]
  78.     # return xc
  79.     return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
  80.  
  81.  
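# get_cu_seqlens builds flash-attn style cumulative sequence boundaries for the packed image+text tokens.
# The returned tensor has 2 * batch_size + 1 entries: for every batch row, one segment covers the image
# tokens plus the valid (unmasked) text tokens, and a second segment covers the padded remainder of that
# row. A small worked example, assuming img_len=10 and a text mask [[1,1,1,0],[1,1,1,1]] (batch of 2,
# max text length 4): the result is [0, 13, 14, 28, 28]. Note the tensor is allocated on "cuda" directly.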
  82. def get_cu_seqlens(text_mask, img_len):
  83.     batch_size = text_mask.shape[0]
  84.     text_len = text_mask.sum(dim=1)
  85.     max_len = text_mask.shape[1] + img_len
  86.  
  87.     cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
  88.  
  89.     for i in range(batch_size):
  90.         s = text_len[i] + img_len
  91.         s1 = i * max_len + s
  92.         s2 = (i + 1) * max_len
  93.         cu_seqlens[2 * i + 1] = s1
  94.         cu_seqlens[2 * i + 2] = s2
  95.  
  96.     return cu_seqlens
  97.  
  98.  
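# apply_rotary_emb_transposed applies rotary position embeddings to tokens laid out as (..., seq, heads, dim).
# freqs_cis packs [cos | sin] along its last dimension (see HunyuanVideoRotaryPosEmbed below) and is
# broadcast across the heads axis via unsqueeze(-2); adjacent channel pairs (x_real, x_imag) are rotated by
# the per-position angles, with the arithmetic done in float32 before casting back to the input dtype.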
  99. def apply_rotary_emb_transposed(x, freqs_cis):
  100.     cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
  101.     x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
  102.     x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
  103.     out = x.float() * cos + x_rotated.float() * sin
  104.     out = out.to(x)
  105.     return out
  106.  
  107.  
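# attn_varlen_func is the core of the GTX 1080 Ti workaround. For the dense path (all cu_seqlens/max_seqlen
# arguments None, i.e. batch size 1) it tries Sage Attention, then xformers in float16, then PyTorch SDPA in
# float16 and finally float32, since FlashAttention and bfloat16 kernels are unavailable on Pascal GPUs.
# For the packed variable-length path it tries sageattn_varlen and otherwise falls back to a per-sample SDPA
# loop driven by cu_seqlens. (An xformers block-diagonal attention bias, e.g. fmha.BlockDiagonalMask, might
# be a cleaner varlen fallback, but that is only a suggestion and is not used here.)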
  108. def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
  109.     if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
  110.         if sageattn is not None:
  111.             x = sageattn(q, k, v, tensor_layout='NHD')
  112.             return x
  113.  
  114.         # if flash_attn_func is not None:
  115.         #     x = flash_attn_func(q, k, v)
  116.         #     return x
  117.  
  118.         if xformers_attn_func is not None:
  119.             # x = xformers_attn_func(q, k, v)
  120.             # Cast to float16 before calling xformers: bf16 attention kernels are generally unavailable on pre-Ampere GPUs such as the GTX 1080 Ti
  121.             q_fp16 = q.to(torch.float16)
  122.             k_fp16 = k.to(torch.float16)
  123.             v_fp16 = v.to(torch.float16)
  124.             try:
  125.                 x = xformers_attn_func(q_fp16, k_fp16, v_fp16)
  126.                 x = x.to(q.dtype) # Cast back to original dtype
  127.                 return x
  128.             except NotImplementedError:
  129.                  # If xformers fails with float16, fall through to SDPA
  130.                  print("xFormers failed with float16, falling back to SDPA.")
  131.                  pass # Let SDPA handle it
  132.  
  133.         # Fallback to SDPA; use float16 or float32, since bfloat16 is not natively supported on Pascal cards like the GTX 1080 Ti
  134.         q_fallback = q.to(torch.float16) # Try float16 first for SDPA
  135.         k_fallback = k.to(torch.float16)
  136.         v_fallback = v.to(torch.float16)
  137.         try:
  138.             # Use PyTorch's built-in SDPA
  139.             x = torch.nn.functional.scaled_dot_product_attention(q_fallback.transpose(1, 2), k_fallback.transpose(1, 2), v_fallback.transpose(1, 2)).transpose(1, 2)
  140.             x = x.to(q.dtype) # Cast back to original dtype
  141.             return x
  142.  
  143.         # x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
  144.         # return x
  145.         except Exception as e_sdpa_fp16:
  146.             print(f"SDPA failed with float16: {e_sdpa_fp16}. Trying float32.")
  147.             # If float16 fails for SDPA (less likely but possible), try float32
  148.             q_fallback_fp32 = q.to(torch.float32)
  149.             k_fallback_fp32 = k.to(torch.float32)
  150.             v_fallback_fp32 = v.to(torch.float32)
  151.             try:
  152.                 x = torch.nn.functional.scaled_dot_product_attention(q_fallback_fp32.transpose(1, 2), k_fallback_fp32.transpose(1, 2), v_fallback_fp32.transpose(1, 2)).transpose(1, 2)
  153.                 x = x.to(q.dtype) # Cast back to original dtype
  154.                 return x
  155.             except Exception as e_sdpa_fp32:
  156.                  print(f"SDPA also failed with float32: {e_sdpa_fp32}. Raising original error.")
  157.                  raise e_sdpa_fp16 # Re-raise the float16 error if float32 also fails
  158.  
  159.     batch_size = q.shape[0]
  160.     # The view operations might need to happen *before* casting if cu_seqlens are involved
  161.     # Let's keep the original view logic for now
  162.     q_orig_shape = q.shape
  163.     k_orig_shape = k.shape
  164.     v_orig_shape = v.shape
  165.     q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
  166.     k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
  167.     v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
  168.  
  169.     # The original code called sageattn_varlen / flash_attn_varlen_func directly here and raised
  170.     # NotImplementedError('No Attn Installed!') when neither was available. flash-attn does not support
  171.     # Pascal GPUs such as the GTX 1080 Ti, so only Sage Attn is attempted before the fallbacks below.
  172.     if sageattn_varlen is not None:
  173.         # Assuming sageattn handles dtypes correctly or needs its own casting
  174.         try:
  175.             x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
  176.         except Exception as e_sage:
  177.             print(f"Sage Attn failed: {e_sage}. Falling back.")
  178.             x = None  # Fall through to the xformers / SDPA fallbacks below
  179.     else:
  180.         x = None  # Sage Attn is not installed; fall through to the fallbacks below
  187.  
  188.     if x is None and xformers_attn_func is not None: # Check if sage failed or wasn't used, and xformers is available
  189.         # Cast to float16 before calling xformers for variable length
  190.         q_fp16 = q.to(torch.float16)
  191.         k_fp16 = k.to(torch.float16)
  192.         v_fp16 = v.to(torch.float16)
  193.         try:
  194.              # xFormers memory_efficient_attention does not directly support cu_seqlens.
  195.              # The original code used flash_attn_varlen_func for this.
  196.              # We might need to pad/unpad manually or use a different xFormers interface if one exists for varlen.
  197.              # For now, trying the standard call, which might be incorrect for packed sequences.
  198.             print("Warning: Calling standard xformers attention for variable length sequence. This might be incorrect. Padding/unpadding or a dedicated varlen function might be needed.")
  199.             x = xformers_attn_func(q_fp16, k_fp16, v_fp16) # This is likely incorrect for varlen!
  200.             x = x.to(q.dtype) # Cast back
  201.         except NotImplementedError:
  202.             print("xFormers (standard) failed for varlen with float16, falling back to SDPA.")
  203.             x = None # Mark xFormers as failed
  204.         except Exception as e_xformers_varlen:
  205.             print(f"xFormers (standard) encountered an error for varlen: {e_xformers_varlen}. Falling back to SDPA.")
  206.             x = None # Mark xFormers as failed
  207.  
  208.  
  209.     if x is None: # If sage/flash/xformers all failed or were unavailable for varlen
  210.         print("No specialized attention backend worked for varlen, falling back to basic SDPA (potentially incorrect for packed sequences).")
  211.         q_fallback = q.to(torch.float16) # Use float16 for fallback
  212.         k_fallback = k.to(torch.float16)
  213.         v_fallback = v.to(torch.float16)
  214.         # A single SDPA call over the packed tensors would ignore cu_seqlens and produce wrong results,
  215.         # so iterate over the per-sample segments instead (slow, but it respects the packing).
  216.         try:
  217.             x_list = []
  218.             current_idx_q = 0
  219.             current_idx_kv = 0
  220.             # Manually iterate based on cu_seqlens (this is slow and inefficient but more correct than a single SDPA call)
  221.             for i in range(batch_size):
  222.                 start_idx_q = cu_seqlens_q[2*i].item()
  223.                 end_idx_q = cu_seqlens_q[2*i+1].item()
  224.                 start_idx_kv = cu_seqlens_kv[2*i].item() # Assuming cu_seqlens_kv matches q structure here
  225.                 end_idx_kv = cu_seqlens_kv[2*i+1].item()
  226.
  227.                 # Slice this sample's valid tokens, shape (seq_len_i, num_heads, head_dim), then add a
  228.                 # batch dim and move heads first for SDPA: (1, num_heads, seq_len_i, head_dim).
  229.                 q_i = q_fallback[start_idx_q:end_idx_q].unsqueeze(0).transpose(1, 2)
  230.                 k_i = k_fallback[start_idx_kv:end_idx_kv].unsqueeze(0).transpose(1, 2)
  231.                 v_i = v_fallback[start_idx_kv:end_idx_kv].unsqueeze(0).transpose(1, 2)
  232.
  233.                 # q_i / k_i / v_i are already in SDPA's (batch, heads, seq, dim) layout.
  234.                 out_i = torch.nn.functional.scaled_dot_product_attention(q_i, k_i, v_i).transpose(1, 2)  # (1, seq_len_i, num_heads, head_dim)
  235.                 # Flatten the heads back into the hidden dim: (1, seq_len_i, num_heads * head_dim)
  236.                 out_i = out_i.reshape(1, -1, q_orig_shape[-2] * q_orig_shape[-1])
  237.                 x_list.append(out_i.squeeze(0))
  241.  
  242.  
  243.             # Pad shorter sequences to max_seqlen_q before cat
  244.             padded_x_list = []
  245.             for x_i in x_list:
  246.                 pad_len = max_seqlen_q - x_i.shape[0]
  247.                 if pad_len > 0:
  248.                      # Pad with zeros, adjust padding dims if needed
  249.                     padded_x_i = torch.nn.functional.pad(x_i, (0, 0, 0, pad_len)) # Pads the sequence length dim (dim 0)
  250.                     padded_x_list.append(padded_x_i)
  251.                 else:
  252.                     padded_x_list.append(x_i)
  253.  
  254.             x = torch.stack(padded_x_list, dim=0) # Shape: (batch_size, max_seqlen_q, hidden_dim)
  255.             x = x.to(q.dtype) # Cast back
  256.  
  257.         except Exception as e_sdpa_varlen:
  258.             print(f"Manual SDPA loop for varlen failed: {e_sdpa_varlen}")
  259.             # If even the manual loop fails, we are stuck. Raise error.
  260.             raise NotImplementedError("Could not execute attention with any backend for variable sequence lengths.")
  261.  
  262.  
  263.     # Both the packed varlen outputs (shape (total_tokens, num_heads, head_dim)) and the padded SDPA
  264.     # fallback (shape (batch_size, max_seqlen_q, num_heads * head_dim)) contain exactly
  265.     # batch_size * max_seqlen_q * num_heads * head_dim elements, so a single view restores the
  266.     # (batch_size, max_seqlen_q, num_heads, head_dim) layout expected by the caller.
  267.     x = x.view(batch_size, max_seqlen_q, *q_orig_shape[2:])
  269.     return x
  270.  
  271.  
  272. class HunyuanAttnProcessorFlashAttnDouble:
  273.     def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
  274.         cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
  275.  
  276.         query = attn.to_q(hidden_states)
  277.         key = attn.to_k(hidden_states)
  278.         value = attn.to_v(hidden_states)
  279.  
  280.         query = query.unflatten(2, (attn.heads, -1))
  281.         key = key.unflatten(2, (attn.heads, -1))
  282.         value = value.unflatten(2, (attn.heads, -1))
  283.  
  284.         query = attn.norm_q(query)
  285.         key = attn.norm_k(key)
  286.  
  287.         query = apply_rotary_emb_transposed(query, image_rotary_emb)
  288.         key = apply_rotary_emb_transposed(key, image_rotary_emb)
  289.  
  290.         encoder_query = attn.add_q_proj(encoder_hidden_states)
  291.         encoder_key = attn.add_k_proj(encoder_hidden_states)
  292.         encoder_value = attn.add_v_proj(encoder_hidden_states)
  293.  
  294.         encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
  295.         encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
  296.         encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
  297.  
  298.         encoder_query = attn.norm_added_q(encoder_query)
  299.         encoder_key = attn.norm_added_k(encoder_key)
  300.  
  301.         query = torch.cat([query, encoder_query], dim=1)
  302.         key = torch.cat([key, encoder_key], dim=1)
  303.         value = torch.cat([value, encoder_value], dim=1)
  304.  
  305.         hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
  306.         hidden_states = hidden_states.flatten(-2)
  307.  
  308.         txt_length = encoder_hidden_states.shape[1]
  309.         hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
  310.  
  311.         hidden_states = attn.to_out[0](hidden_states)
  312.         hidden_states = attn.to_out[1](hidden_states)
  313.         encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
  314.  
  315.         return hidden_states, encoder_hidden_states
  316.  
  317.  
  318. class HunyuanAttnProcessorFlashAttnSingle:
  319.     def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
  320.         cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
  321.  
  322.         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
  323.  
  324.         query = attn.to_q(hidden_states)
  325.         key = attn.to_k(hidden_states)
  326.         value = attn.to_v(hidden_states)
  327.  
  328.         query = query.unflatten(2, (attn.heads, -1))
  329.         key = key.unflatten(2, (attn.heads, -1))
  330.         value = value.unflatten(2, (attn.heads, -1))
  331.  
  332.         query = attn.norm_q(query)
  333.         key = attn.norm_k(key)
  334.  
  335.         txt_length = encoder_hidden_states.shape[1]
  336.  
  337.         query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
  338.         key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
  339.  
  340.         hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
  341.         hidden_states = hidden_states.flatten(-2)
  342.  
  343.         hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
  344.  
  345.         return hidden_states, encoder_hidden_states
  346.  
  347.  
  348. class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
  349.     def __init__(self, embedding_dim, pooled_projection_dim):
  350.         super().__init__()
  351.  
  352.         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
  353.         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
  354.         self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
  355.         self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
  356.  
  357.     def forward(self, timestep, guidance, pooled_projection):
  358.         timesteps_proj = self.time_proj(timestep)
  359.         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
  360.  
  361.         guidance_proj = self.time_proj(guidance)
  362.         guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
  363.  
  364.         time_guidance_emb = timesteps_emb + guidance_emb
  365.  
  366.         pooled_projections = self.text_embedder(pooled_projection)
  367.         conditioning = time_guidance_emb + pooled_projections
  368.  
  369.         return conditioning
  370.  
  371.  
  372. class CombinedTimestepTextProjEmbeddings(nn.Module):
  373.     def __init__(self, embedding_dim, pooled_projection_dim):
  374.         super().__init__()
  375.  
  376.         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
  377.         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
  378.         self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
  379.  
  380.     def forward(self, timestep, pooled_projection):
  381.         timesteps_proj = self.time_proj(timestep)
  382.         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
  383.  
  384.         pooled_projections = self.text_embedder(pooled_projection)
  385.  
  386.         conditioning = timesteps_emb + pooled_projections
  387.  
  388.         return conditioning
  389.  
  390.  
  391. class HunyuanVideoAdaNorm(nn.Module):
  392.     def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
  393.         super().__init__()
  394.  
  395.         out_features = out_features or 2 * in_features
  396.         self.linear = nn.Linear(in_features, out_features)
  397.         self.nonlinearity = nn.SiLU()
  398.  
  399.     def forward(
  400.         self, temb: torch.Tensor
  401.     ) -> Tuple[torch.Tensor, torch.Tensor]:
  402.         temb = self.linear(self.nonlinearity(temb))
  403.         gate_msa, gate_mlp = temb.chunk(2, dim=-1)
  404.         gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
  405.         return gate_msa, gate_mlp
  406.  
  407.  
  408. class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
  409.     def __init__(
  410.         self,
  411.         num_attention_heads: int,
  412.         attention_head_dim: int,
  413.         mlp_width_ratio: float = 4.0,
  414.         mlp_drop_rate: float = 0.0,
  415.         attention_bias: bool = True,
  416.     ) -> None:
  417.         super().__init__()
  418.  
  419.         hidden_size = num_attention_heads * attention_head_dim
  420.  
  421.         self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
  422.         self.attn = Attention(
  423.             query_dim=hidden_size,
  424.             cross_attention_dim=None,
  425.             heads=num_attention_heads,
  426.             dim_head=attention_head_dim,
  427.             bias=attention_bias,
  428.         )
  429.  
  430.         self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
  431.         self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
  432.  
  433.         self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
  434.  
  435.     def forward(
  436.         self,
  437.         hidden_states: torch.Tensor,
  438.         temb: torch.Tensor,
  439.         attention_mask: Optional[torch.Tensor] = None,
  440.     ) -> torch.Tensor:
  441.         norm_hidden_states = self.norm1(hidden_states)
  442.  
  443.         attn_output = self.attn(
  444.             hidden_states=norm_hidden_states,
  445.             encoder_hidden_states=None,
  446.             attention_mask=attention_mask,
  447.         )
  448.  
  449.         gate_msa, gate_mlp = self.norm_out(temb)
  450.         hidden_states = hidden_states + attn_output * gate_msa
  451.  
  452.         ff_output = self.ff(self.norm2(hidden_states))
  453.         hidden_states = hidden_states + ff_output * gate_mlp
  454.  
  455.         return hidden_states
  456.  
  457.  
  458. class HunyuanVideoIndividualTokenRefiner(nn.Module):
  459.     def __init__(
  460.         self,
  461.         num_attention_heads: int,
  462.         attention_head_dim: int,
  463.         num_layers: int,
  464.         mlp_width_ratio: float = 4.0,
  465.         mlp_drop_rate: float = 0.0,
  466.         attention_bias: bool = True,
  467.     ) -> None:
  468.         super().__init__()
  469.  
  470.         self.refiner_blocks = nn.ModuleList(
  471.             [
  472.                 HunyuanVideoIndividualTokenRefinerBlock(
  473.                     num_attention_heads=num_attention_heads,
  474.                     attention_head_dim=attention_head_dim,
  475.                     mlp_width_ratio=mlp_width_ratio,
  476.                     mlp_drop_rate=mlp_drop_rate,
  477.                     attention_bias=attention_bias,
  478.                 )
  479.                 for _ in range(num_layers)
  480.             ]
  481.         )
  482.  
  483.     def forward(
  484.         self,
  485.         hidden_states: torch.Tensor,
  486.         temb: torch.Tensor,
  487.         attention_mask: Optional[torch.Tensor] = None,
  488.     ) -> torch.Tensor:
  489.         self_attn_mask = None
  490.         if attention_mask is not None:
  491.             batch_size = attention_mask.shape[0]
  492.             seq_len = attention_mask.shape[1]
  493.             attention_mask = attention_mask.to(hidden_states.device).bool()
  494.             self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
  495.             self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
  496.             self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
  497.             self_attn_mask[:, :, :, 0] = True
  498.  
  499.         for block in self.refiner_blocks:
  500.             hidden_states = block(hidden_states, temb, self_attn_mask)
  501.  
  502.         return hidden_states
  503.  
  504.  
  505. class HunyuanVideoTokenRefiner(nn.Module):
  506.     def __init__(
  507.         self,
  508.         in_channels: int,
  509.         num_attention_heads: int,
  510.         attention_head_dim: int,
  511.         num_layers: int,
  512.         mlp_ratio: float = 4.0,
  513.         mlp_drop_rate: float = 0.0,
  514.         attention_bias: bool = True,
  515.     ) -> None:
  516.         super().__init__()
  517.  
  518.         hidden_size = num_attention_heads * attention_head_dim
  519.  
  520.         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
  521.             embedding_dim=hidden_size, pooled_projection_dim=in_channels
  522.         )
  523.         self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
  524.         self.token_refiner = HunyuanVideoIndividualTokenRefiner(
  525.             num_attention_heads=num_attention_heads,
  526.             attention_head_dim=attention_head_dim,
  527.             num_layers=num_layers,
  528.             mlp_width_ratio=mlp_ratio,
  529.             mlp_drop_rate=mlp_drop_rate,
  530.             attention_bias=attention_bias,
  531.         )
  532.  
  533.     def forward(
  534.         self,
  535.         hidden_states: torch.Tensor,
  536.         timestep: torch.LongTensor,
  537.         attention_mask: Optional[torch.LongTensor] = None,
  538.     ) -> torch.Tensor:
  539.         if attention_mask is None:
  540.             pooled_projections = hidden_states.mean(dim=1)
  541.         else:
  542.             original_dtype = hidden_states.dtype
  543.             mask_float = attention_mask.float().unsqueeze(-1)
  544.             pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
  545.             pooled_projections = pooled_projections.to(original_dtype)
  546.  
  547.         temb = self.time_text_embed(timestep, pooled_projections)
  548.         hidden_states = self.proj_in(hidden_states)
  549.         hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
  550.  
  551.         return hidden_states
  552.  
  553.  
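# HunyuanVideoRotaryPosEmbed computes 3D rotary embeddings. rope_dim = (DT, DY, DX) splits the attention
# head dim across the time, height and width axes (the default rope_axes_dim of (16, 56, 56) sums to the
# default attention_head_dim of 128). forward() returns, per batch element, a (2 * head_dim, T, H, W) tensor
# that concatenates the per-axis cosines followed by the per-axis sines; each frequency is repeat-interleaved
# by 2 so it lines up with the channel pairs rotated in apply_rotary_emb_transposed.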
  554. class HunyuanVideoRotaryPosEmbed(nn.Module):
  555.     def __init__(self, rope_dim, theta):
  556.         super().__init__()
  557.         self.DT, self.DY, self.DX = rope_dim
  558.         self.theta = theta
  559.  
  560.     @torch.no_grad()
  561.     def get_frequency(self, dim, pos):
  562.         T, H, W = pos.shape
  563.         freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
  564.         freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
  565.         return freqs.cos(), freqs.sin()
  566.  
  567.     @torch.no_grad()
  568.     def forward_inner(self, frame_indices, height, width, device):
  569.         GT, GY, GX = torch.meshgrid(
  570.             frame_indices.to(device=device, dtype=torch.float32),
  571.             torch.arange(0, height, device=device, dtype=torch.float32),
  572.             torch.arange(0, width, device=device, dtype=torch.float32),
  573.             indexing="ij"
  574.         )
  575.  
  576.         FCT, FST = self.get_frequency(self.DT, GT)
  577.         FCY, FSY = self.get_frequency(self.DY, GY)
  578.         FCX, FSX = self.get_frequency(self.DX, GX)
  579.  
  580.         result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
  581.  
  582.         return result.to(device)
  583.  
  584.     @torch.no_grad()
  585.     def forward(self, frame_indices, height, width, device):
  586.         frame_indices = frame_indices.unbind(0)
  587.         results = [self.forward_inner(f, height, width, device) for f in frame_indices]
  588.         results = torch.stack(results, dim=0)
  589.         return results
  590.  
  591.  
  592. class AdaLayerNormZero(nn.Module):
  593.     def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
  594.         super().__init__()
  595.         self.silu = nn.SiLU()
  596.         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
  597.         if norm_type == "layer_norm":
  598.             self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
  599.         else:
  600.             raise ValueError(f"unknown norm_type {norm_type}")
  601.  
  602.     def forward(
  603.         self,
  604.         x: torch.Tensor,
  605.         emb: Optional[torch.Tensor] = None,
  606.     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  607.         emb = emb.unsqueeze(-2)
  608.         emb = self.linear(self.silu(emb))
  609.         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
  610.         x = self.norm(x) * (1 + scale_msa) + shift_msa
  611.         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
  612.  
  613.  
  614. class AdaLayerNormZeroSingle(nn.Module):
  615.     def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
  616.         super().__init__()
  617.  
  618.         self.silu = nn.SiLU()
  619.         self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
  620.         if norm_type == "layer_norm":
  621.             self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
  622.         else:
  623.             raise ValueError(f"unknown norm_type {norm_type}")
  624.  
  625.     def forward(
  626.         self,
  627.         x: torch.Tensor,
  628.         emb: Optional[torch.Tensor] = None,
  629.     ) -> Tuple[torch.Tensor, torch.Tensor]:
  630.         emb = emb.unsqueeze(-2)
  631.         emb = self.linear(self.silu(emb))
  632.         shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
  633.         x = self.norm(x) * (1 + scale_msa) + shift_msa
  634.         return x, gate_msa
  635.  
  636.  
  637. class AdaLayerNormContinuous(nn.Module):
  638.     def __init__(
  639.         self,
  640.         embedding_dim: int,
  641.         conditioning_embedding_dim: int,
  642.         elementwise_affine=True,
  643.         eps=1e-5,
  644.         bias=True,
  645.         norm_type="layer_norm",
  646.     ):
  647.         super().__init__()
  648.         self.silu = nn.SiLU()
  649.         self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
  650.         if norm_type == "layer_norm":
  651.             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
  652.         else:
  653.             raise ValueError(f"unknown norm_type {norm_type}")
  654.  
  655.     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
  656.         emb = emb.unsqueeze(-2)
  657.         emb = self.linear(self.silu(emb))
  658.         scale, shift = emb.chunk(2, dim=-1)
  659.         x = self.norm(x) * (1 + scale) + shift
  660.         return x
  661.  
  662.  
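# The two transformer block types below implement HunyuanVideo's single-stream / dual-stream split: the
# "single" block concatenates video and text tokens into one sequence and runs attention and the MLP branch
# in parallel on the AdaLN-modulated input, while HunyuanVideoTransformerBlock (the dual-stream block) keeps
# separate modulation and feed-forward paths for video and text and joins them only inside the attention call.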
  663. class HunyuanVideoSingleTransformerBlock(nn.Module):
  664.     def __init__(
  665.         self,
  666.         num_attention_heads: int,
  667.         attention_head_dim: int,
  668.         mlp_ratio: float = 4.0,
  669.         qk_norm: str = "rms_norm",
  670.     ) -> None:
  671.         super().__init__()
  672.  
  673.         hidden_size = num_attention_heads * attention_head_dim
  674.         mlp_dim = int(hidden_size * mlp_ratio)
  675.  
  676.         self.attn = Attention(
  677.             query_dim=hidden_size,
  678.             cross_attention_dim=None,
  679.             dim_head=attention_head_dim,
  680.             heads=num_attention_heads,
  681.             out_dim=hidden_size,
  682.             bias=True,
  683.             processor=HunyuanAttnProcessorFlashAttnSingle(),
  684.             qk_norm=qk_norm,
  685.             eps=1e-6,
  686.             pre_only=True,
  687.         )
  688.  
  689.         self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
  690.         self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
  691.         self.act_mlp = nn.GELU(approximate="tanh")
  692.         self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
  693.  
  694.     def forward(
  695.         self,
  696.         hidden_states: torch.Tensor,
  697.         encoder_hidden_states: torch.Tensor,
  698.         temb: torch.Tensor,
  699.         attention_mask: Optional[torch.Tensor] = None,
  700.         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  701.     ) -> Tuple[torch.Tensor, torch.Tensor]:
  702.         text_seq_length = encoder_hidden_states.shape[1]
  703.         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
  704.  
  705.         residual = hidden_states
  706.  
  707.         # 1. Input normalization
  708.         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
  709.         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
  710.  
  711.         norm_hidden_states, norm_encoder_hidden_states = (
  712.             norm_hidden_states[:, :-text_seq_length, :],
  713.             norm_hidden_states[:, -text_seq_length:, :],
  714.         )
  715.  
  716.         # 2. Attention
  717.         attn_output, context_attn_output = self.attn(
  718.             hidden_states=norm_hidden_states,
  719.             encoder_hidden_states=norm_encoder_hidden_states,
  720.             attention_mask=attention_mask,
  721.             image_rotary_emb=image_rotary_emb,
  722.         )
  723.         attn_output = torch.cat([attn_output, context_attn_output], dim=1)
  724.  
  725.         # 3. Modulation and residual connection
  726.         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
  727.         hidden_states = gate * self.proj_out(hidden_states)
  728.         hidden_states = hidden_states + residual
  729.  
  730.         hidden_states, encoder_hidden_states = (
  731.             hidden_states[:, :-text_seq_length, :],
  732.             hidden_states[:, -text_seq_length:, :],
  733.         )
  734.         return hidden_states, encoder_hidden_states
  735.  
  736.  
  737. class HunyuanVideoTransformerBlock(nn.Module):
  738.     def __init__(
  739.         self,
  740.         num_attention_heads: int,
  741.         attention_head_dim: int,
  742.         mlp_ratio: float,
  743.         qk_norm: str = "rms_norm",
  744.     ) -> None:
  745.         super().__init__()
  746.  
  747.         hidden_size = num_attention_heads * attention_head_dim
  748.  
  749.         self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
  750.         self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
  751.  
  752.         self.attn = Attention(
  753.             query_dim=hidden_size,
  754.             cross_attention_dim=None,
  755.             added_kv_proj_dim=hidden_size,
  756.             dim_head=attention_head_dim,
  757.             heads=num_attention_heads,
  758.             out_dim=hidden_size,
  759.             context_pre_only=False,
  760.             bias=True,
  761.             processor=HunyuanAttnProcessorFlashAttnDouble(),
  762.             qk_norm=qk_norm,
  763.             eps=1e-6,
  764.         )
  765.  
  766.         self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
  767.         self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
  768.  
  769.         self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
  770.         self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
  771.  
  772.     def forward(
  773.         self,
  774.         hidden_states: torch.Tensor,
  775.         encoder_hidden_states: torch.Tensor,
  776.         temb: torch.Tensor,
  777.         attention_mask: Optional[torch.Tensor] = None,
  778.         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  779.     ) -> Tuple[torch.Tensor, torch.Tensor]:
  780.         # 1. Input normalization
  781.         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
  782.         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
  783.  
  784.         # 2. Joint attention
  785.         attn_output, context_attn_output = self.attn(
  786.             hidden_states=norm_hidden_states,
  787.             encoder_hidden_states=norm_encoder_hidden_states,
  788.             attention_mask=attention_mask,
  789.             image_rotary_emb=freqs_cis,
  790.         )
  791.  
  792.         # 3. Modulation and residual connection
  793.         hidden_states = hidden_states + attn_output * gate_msa
  794.         encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
  795.  
  796.         norm_hidden_states = self.norm2(hidden_states)
  797.         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
  798.  
  799.         norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
  800.         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
  801.  
  802.         # 4. Feed-forward
  803.         ff_output = self.ff(norm_hidden_states)
  804.         context_ff_output = self.ff_context(norm_encoder_hidden_states)
  805.  
  806.         hidden_states = hidden_states + gate_mlp * ff_output
  807.         encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
  808.  
  809.         return hidden_states, encoder_hidden_states
  810.  
  811.  
  812. class ClipVisionProjection(nn.Module):
  813.     def __init__(self, in_channels, out_channels):
  814.         super().__init__()
  815.         self.up = nn.Linear(in_channels, out_channels * 3)
  816.         self.down = nn.Linear(out_channels * 3, out_channels)
  817.  
  818.     def forward(self, x):
  819.         projected_x = self.down(nn.functional.silu(self.up(x)))
  820.         return projected_x
  821.  
  822.  
  823. class HunyuanVideoPatchEmbed(nn.Module):
  824.     def __init__(self, patch_size, in_chans, embed_dim):
  825.         super().__init__()
  826.         self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  827.  
  828.  
  829. class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
  830.     def __init__(self, inner_dim):
  831.         super().__init__()
  832.         self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
  833.         self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
  834.         self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
  835.  
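    # initialize_weight_from_another_conv3d copies the base (1, 2, 2) patch-embed kernel and tiles it over
    # the larger 2x (2, 4, 4) and 4x (4, 8, 8) kernels, dividing by the number of tiles (8 and 64). At
    # initialization the coarser embedders therefore behave like the base embedder applied to an
    # average-pooled input, which keeps the multi-scale clean-latent tokens on a comparable scale.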
  836.     @torch.no_grad()
  837.     def initialize_weight_from_another_conv3d(self, another_layer):
  838.         weight = another_layer.weight.detach().clone()
  839.         bias = another_layer.bias.detach().clone()
  840.  
  841.         sd = {
  842.             'proj.weight': weight.clone(),
  843.             'proj.bias': bias.clone(),
  844.             'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
  845.             'proj_2x.bias': bias.clone(),
  846.             'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
  847.             'proj_4x.bias': bias.clone(),
  848.         }
  849.  
  850.         sd = {k: v.clone() for k, v in sd.items()}
  851.  
  852.         self.load_state_dict(sd)
  853.         return
  854.  
  855.  
  856. class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
  857.     @register_to_config
  858.     def __init__(
  859.         self,
  860.         in_channels: int = 16,
  861.         out_channels: int = 16,
  862.         num_attention_heads: int = 24,
  863.         attention_head_dim: int = 128,
  864.         num_layers: int = 20,
  865.         num_single_layers: int = 40,
  866.         num_refiner_layers: int = 2,
  867.         mlp_ratio: float = 4.0,
  868.         patch_size: int = 2,
  869.         patch_size_t: int = 1,
  870.         qk_norm: str = "rms_norm",
  871.         guidance_embeds: bool = True,
  872.         text_embed_dim: int = 4096,
  873.         pooled_projection_dim: int = 768,
  874.         rope_theta: float = 256.0,
  875.         rope_axes_dim: Tuple[int, int, int] = (16, 56, 56),
  876.         has_image_proj=False,
  877.         image_proj_dim=1152,
  878.         has_clean_x_embedder=False,
  879.     ) -> None:
  880.         super().__init__()
  881.  
  882.         inner_dim = num_attention_heads * attention_head_dim
  883.         out_channels = out_channels or in_channels
  884.  
  885.         # 1. Latent and condition embedders
  886.         self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
  887.         self.context_embedder = HunyuanVideoTokenRefiner(
  888.             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
  889.         )
  890.         self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
  891.  
  892.         self.clean_x_embedder = None
  893.         self.image_projection = None
  894.  
  895.         # 2. RoPE
  896.         self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
  897.  
  898.         # 3. Dual stream transformer blocks
  899.         self.transformer_blocks = nn.ModuleList(
  900.             [
  901.                 HunyuanVideoTransformerBlock(
  902.                     num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
  903.                 )
  904.                 for _ in range(num_layers)
  905.             ]
  906.         )
  907.  
  908.         # 4. Single stream transformer blocks
  909.         self.single_transformer_blocks = nn.ModuleList(
  910.             [
  911.                 HunyuanVideoSingleTransformerBlock(
  912.                     num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
  913.                 )
  914.                 for _ in range(num_single_layers)
  915.             ]
  916.         )
  917.  
  918.         # 5. Output projection
  919.         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
  920.         self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
  921.  
  922.         self.inner_dim = inner_dim
  923.         self.use_gradient_checkpointing = False
  924.         self.enable_teacache = False
  925.  
  926.         if has_image_proj:
  927.             self.install_image_projection(image_proj_dim)
  928.  
  929.         if has_clean_x_embedder:
  930.             self.install_clean_x_embedder()
  931.  
  932.         self.high_quality_fp32_output_for_inference = False
  933.  
  934.     def install_image_projection(self, in_channels):
  935.         self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
  936.         self.config['has_image_proj'] = True
  937.         self.config['image_proj_dim'] = in_channels
  938.  
  939.     def install_clean_x_embedder(self):
  940.         self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
  941.         self.config['has_clean_x_embedder'] = True
  942.  
  943.     def enable_gradient_checkpointing(self):
  944.         self.use_gradient_checkpointing = True
  945.         print('self.use_gradient_checkpointing = True')
  946.  
  947.     def disable_gradient_checkpointing(self):
  948.         self.use_gradient_checkpointing = False
  949.         print('self.use_gradient_checkpointing = False')
  950.  
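    # TeaCache support: initialize_teacache() configures a per-step cache. At each denoising step the first
    # dual-stream block's modulated input is compared with the previous step's; the relative L1 change is
    # rescaled by a fixed polynomial (teacache_rescale_func) and accumulated, and while the accumulated value
    # stays below rel_l1_thresh the transformer blocks are skipped and the cached residual is re-applied
    # instead (per the inline note, 0.1 targets roughly a 1.6x speedup and 0.15 roughly 2.1x).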
  951.     def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
  952.         self.enable_teacache = enable_teacache
  953.         self.cnt = 0
  954.         self.num_steps = num_steps
  955.         self.rel_l1_thresh = rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
  956.         self.accumulated_rel_l1_distance = 0
  957.         self.previous_modulated_input = None
  958.         self.previous_residual = None
  959.         self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
  960.  
  961.     def gradient_checkpointing_method(self, block, *args):
  962.         if self.use_gradient_checkpointing:
  963.             result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
  964.         else:
  965.             result = block(*args)
  966.         return result
  967.  
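    # process_input_hidden_states patch-embeds the current noisy latents and, when provided, the "clean"
    # history latents at 1x, 2x and 4x compression, concatenating them in front of the current tokens along
    # the sequence dim together with matching RoPE frequencies. Because the context tokens are prepended,
    # forward() later recovers the current section by keeping only the last original_context_length tokens
    # before unpatchifying.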
  968.     def process_input_hidden_states(
  969.             self,
  970.             latents, latent_indices=None,
  971.             clean_latents=None, clean_latent_indices=None,
  972.             clean_latents_2x=None, clean_latent_2x_indices=None,
  973.             clean_latents_4x=None, clean_latent_4x_indices=None
  974.     ):
  975.         hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
  976.         B, C, T, H, W = hidden_states.shape
  977.  
  978.         if latent_indices is None:
  979.             latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
  980.  
  981.         hidden_states = hidden_states.flatten(2).transpose(1, 2)
  982.  
  983.         rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
  984.         rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
  985.  
  986.         if clean_latents is not None and clean_latent_indices is not None:
  987.             clean_latents = clean_latents.to(hidden_states)
  988.             clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
  989.             clean_latents = clean_latents.flatten(2).transpose(1, 2)
  990.  
  991.             clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
  992.             clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
  993.  
  994.             hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
  995.             rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
  996.  
  997.         if clean_latents_2x is not None and clean_latent_2x_indices is not None:
  998.             clean_latents_2x = clean_latents_2x.to(hidden_states)
  999.             clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
  1000.             clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
  1001.             clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
  1002.  
  1003.             clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
  1004.             clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
  1005.             clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
  1006.             clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
  1007.  
  1008.             hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
  1009.             rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
  1010.  
  1011.         if clean_latents_4x is not None and clean_latent_4x_indices is not None:
  1012.             clean_latents_4x = clean_latents_4x.to(hidden_states)
  1013.             clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
  1014.             clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
  1015.             clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
  1016.  
  1017.             clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
  1018.             clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
  1019.             clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
  1020.             clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
  1021.  
  1022.             hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
  1023.             rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
  1024.  
  1025.         return hidden_states, rope_freqs
  1026.  
  1027.     def forward(
  1028.             self,
  1029.             hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
  1030.             latent_indices=None,
  1031.             clean_latents=None, clean_latent_indices=None,
  1032.             clean_latents_2x=None, clean_latent_2x_indices=None,
  1033.             clean_latents_4x=None, clean_latent_4x_indices=None,
  1034.             image_embeddings=None,
  1035.             attention_kwargs=None, return_dict=True
  1036.     ):
  1037.  
  1038.         if attention_kwargs is None:
  1039.             attention_kwargs = {}
  1040.  
  1041.         batch_size, num_channels, num_frames, height, width = hidden_states.shape
  1042.         p, p_t = self.config['patch_size'], self.config['patch_size_t']
  1043.         post_patch_num_frames = num_frames // p_t
  1044.         post_patch_height = height // p
  1045.         post_patch_width = width // p
  1046.         original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
  1047.  
  1048.         hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
  1049.  
  1050.         temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
  1051.         encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
  1052.  
  1053.         if self.image_projection is not None:
  1054.             assert image_embeddings is not None, 'You must use image embeddings!'
  1055.             extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
  1056.             extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
  1057.  
  1058.             # must cat before (not after) encoder_hidden_states, due to attn masking
  1059.             encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
  1060.             encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
  1061.  
  1062.         with torch.no_grad():
  1063.             if batch_size == 1:
  1064.                 # When batch size is 1, no masks or var-len funcs are needed: cropping to the valid
  1065.                 # text length is mathematically equivalent to masked attention.
  1066.                 text_len = encoder_attention_mask.sum().item()
  1067.                 encoder_hidden_states = encoder_hidden_states[:, :text_len]
  1068.                 attention_mask = None, None, None, None
  1069.             else:
  1070.                 img_seq_len = hidden_states.shape[1]
  1071.                 txt_seq_len = encoder_hidden_states.shape[1]
  1072.  
  1073.                 cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
  1074.                 cu_seqlens_kv = cu_seqlens_q
  1075.                 max_seqlen_q = img_seq_len + txt_seq_len
  1076.                 max_seqlen_kv = max_seqlen_q
  1077.  
  1078.                 attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
  1079.  
  1080.         if self.enable_teacache:
  1081.             modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
  1082.  
  1083.             if self.cnt == 0 or self.cnt == self.num_steps-1:
  1084.                 should_calc = True
  1085.                 self.accumulated_rel_l1_distance = 0
  1086.             else:
  1087.                 curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
  1088.                 self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
  1089.                 should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
  1090.  
  1091.                 if should_calc:
  1092.                     self.accumulated_rel_l1_distance = 0
  1093.  
  1094.             self.previous_modulated_input = modulated_inp
  1095.             self.cnt += 1
  1096.  
  1097.             if self.cnt == self.num_steps:
  1098.                 self.cnt = 0
  1099.  
  1100.             if not should_calc:
  1101.                 hidden_states = hidden_states + self.previous_residual
  1102.             else:
  1103.                 ori_hidden_states = hidden_states.clone()
  1104.  
  1105.                 for block_id, block in enumerate(self.transformer_blocks):
  1106.                     hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
  1107.                         block,
  1108.                         hidden_states,
  1109.                         encoder_hidden_states,
  1110.                         temb,
  1111.                         attention_mask,
  1112.                         rope_freqs
  1113.                     )
  1114.  
  1115.                 for block_id, block in enumerate(self.single_transformer_blocks):
  1116.                     hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
  1117.                         block,
  1118.                         hidden_states,
  1119.                         encoder_hidden_states,
  1120.                         temb,
  1121.                         attention_mask,
  1122.                         rope_freqs
  1123.                     )
  1124.  
  1125.                 self.previous_residual = hidden_states - ori_hidden_states
  1126.         else:
  1127.             for block_id, block in enumerate(self.transformer_blocks):
  1128.                 hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
  1129.                     block,
  1130.                     hidden_states,
  1131.                     encoder_hidden_states,
  1132.                     temb,
  1133.                     attention_mask,
  1134.                     rope_freqs
  1135.                 )
  1136.  
  1137.             for block_id, block in enumerate(self.single_transformer_blocks):
  1138.                 hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
  1139.                     block,
  1140.                     hidden_states,
  1141.                     encoder_hidden_states,
  1142.                     temb,
  1143.                     attention_mask,
  1144.                     rope_freqs
  1145.                 )
  1146.  
  1147.         hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
  1148.  
  1149.         hidden_states = hidden_states[:, -original_context_length:, :]
  1150.  
  1151.         if self.high_quality_fp32_output_for_inference:
  1152.             hidden_states = hidden_states.to(dtype=torch.float32)
  1153.             if self.proj_out.weight.dtype != torch.float32:
  1154.                 self.proj_out.to(dtype=torch.float32)
  1155.  
  1156.         hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
  1157.  
  1158.         hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
  1159.                                          t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
  1160.                                          pt=p_t, ph=p, pw=p)
  1161.  
  1162.         if return_dict:
  1163.             return Transformer2DModelOutput(sample=hidden_states)
  1164.  
  1165.         return hidden_states,
  1166.  
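
# The block below is an optional smoke test added alongside this paste; it is not part of the original
# FramePack module. It only exercises the standalone helpers above with small, illustrative tensor shapes,
# so the packing utilities used by the GTX 1080 Ti fallbacks can be sanity-checked without loading weights.
if __name__ == '__main__':
    # pad_for_3d_conv rounds (t, h, w) up to multiples of the kernel; avg-pool then reduces by the kernel.
    x = torch.randn(1, 16, 3, 5, 7)
    padded = pad_for_3d_conv(x, (2, 4, 4))
    pooled = center_down_sample_3d(padded, (2, 4, 4))
    print('pad_for_3d_conv :', tuple(x.shape), '->', tuple(padded.shape))  # -> (1, 16, 4, 8, 8)
    print('center_down_sample_3d ->', tuple(pooled.shape))                 # -> (1, 16, 2, 2, 2)

    # Rotary embedding helper: freqs packs [cos | sin] along the last dim and broadcasts over heads.
    seq_len, num_heads, head_dim = 6, 2, 8
    q = torch.randn(1, seq_len, num_heads, head_dim)
    freqs = torch.randn(1, seq_len, head_dim * 2)
    print('apply_rotary_emb_transposed ->', tuple(apply_rotary_emb_transposed(q, freqs).shape))

    if torch.cuda.is_available():  # get_cu_seqlens allocates its result on CUDA
        text_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]], dtype=torch.bool, device='cuda')
        print('get_cu_seqlens ->', get_cu_seqlens(text_mask, img_len=10).tolist())  # [0, 13, 14, 28, 28]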