from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import einops
import torch.nn as nn
import numpy as np

from diffusers.loaders import FromOriginalModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import logging
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers_helper.dit_common import LayerNorm
from diffusers_helper.utils import zero_module


enabled_backends = []

if torch.backends.cuda.flash_sdp_enabled():
    enabled_backends.append("flash")
if torch.backends.cuda.math_sdp_enabled():
    enabled_backends.append("math")
if torch.backends.cuda.mem_efficient_sdp_enabled():
    enabled_backends.append("mem_efficient")
if torch.backends.cuda.cudnn_sdp_enabled():
    enabled_backends.append("cudnn")

print("Currently enabled native sdp backends:", enabled_backends)

try:
    # raise NotImplementedError
    from xformers.ops import memory_efficient_attention as xformers_attn_func
    print('Xformers is installed!')
except:
    print('Xformers is not installed!')
    xformers_attn_func = None

try:
    # raise NotImplementedError
    from flash_attn import flash_attn_varlen_func, flash_attn_func
    print('Flash Attn is installed!')
except:
    print('Flash Attn is not installed!')
    flash_attn_varlen_func = None
    flash_attn_func = None

try:
    # raise NotImplementedError
    from sageattention import sageattn_varlen, sageattn
    print('Sage Attn is installed!')
except:
    print('Sage Attn is not installed!')
    sageattn_varlen = None
    sageattn = None


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def pad_for_3d_conv(x, kernel_size):
    b, c, t, h, w = x.shape
    pt, ph, pw = kernel_size
    pad_t = (pt - (t % pt)) % pt
    pad_h = (ph - (h % ph)) % ph
    pad_w = (pw - (w % pw)) % pw
    return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')

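# Illustrative note (not from the original file): with kernel_size=(2, 4, 4) and an
# input of shape (b, c, 3, 5, 6), pad_for_3d_conv replicate-pads t/h/w up to the next
# multiples of the kernel, giving (b, c, 4, 8, 8), so a Conv3d with that kernel and
# stride sees no partial windows.
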
def center_down_sample_3d(x, kernel_size):
    # pt, ph, pw = kernel_size
    # cp = (pt * ph * pw) // 2
    # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
    # xc = xp[cp]
    # return xc
    return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)

def get_cu_seqlens(text_mask, img_len):
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        s = text_len[i] + img_len
        s1 = i * max_len + s
        s2 = (i + 1) * max_len
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2

    return cu_seqlens

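# Worked example (illustrative, not from the original file): for text_mask of shape
# (2, 8) with per-sample text lengths [3, 5] and img_len=10, max_len is 18 and the
# returned tensor is [0, 13, 18, 33, 36]: each sample contributes one boundary after
# its valid (image + text) tokens and one after its padded length, which is the
# cumulative-sequence-length layout the varlen attention kernels below expect.
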
def apply_rotary_emb_transposed(x, freqs_cis):
    cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
    x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    out = x.float() * cos + x_rotated.float() * sin
    out = out.to(x)
    return out

def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
    if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
        if sageattn is not None:
            x = sageattn(q, k, v, tensor_layout='NHD')
            return x

        # if flash_attn_func is not None:
        #     x = flash_attn_func(q, k, v)
        #     return x

        if xformers_attn_func is not None:
            # x = xformers_attn_func(q, k, v)
            # Cast to float16 before calling xformers as bf16 might not be supported
            q_fp16 = q.to(torch.float16)
            k_fp16 = k.to(torch.float16)
            v_fp16 = v.to(torch.float16)
            try:
                x = xformers_attn_func(q_fp16, k_fp16, v_fp16)
                x = x.to(q.dtype)  # Cast back to original dtype
                return x
            except NotImplementedError:
                # If xformers fails with float16, fall through to SDPA
                print("xFormers failed with float16, falling back to SDPA.")
                pass  # Let SDPA handle it

        # Fallback to SDPA, ensure float16 or float32 as bfloat16 might not be supported natively on Turing
        q_fallback = q.to(torch.float16)  # Try float16 first for SDPA
        k_fallback = k.to(torch.float16)
        v_fallback = v.to(torch.float16)
        try:
            # Use PyTorch's built-in SDPA
            x = torch.nn.functional.scaled_dot_product_attention(q_fallback.transpose(1, 2), k_fallback.transpose(1, 2), v_fallback.transpose(1, 2)).transpose(1, 2)
            x = x.to(q.dtype)  # Cast back to original dtype
            return x
            # x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
            # return x
        except Exception as e_sdpa_fp16:
            print(f"SDPA failed with float16: {e_sdpa_fp16}. Trying float32.")
            # If float16 fails for SDPA (less likely but possible), try float32
            q_fallback_fp32 = q.to(torch.float32)
            k_fallback_fp32 = k.to(torch.float32)
            v_fallback_fp32 = v.to(torch.float32)
            try:
                x = torch.nn.functional.scaled_dot_product_attention(q_fallback_fp32.transpose(1, 2), k_fallback_fp32.transpose(1, 2), v_fallback_fp32.transpose(1, 2)).transpose(1, 2)
                x = x.to(q.dtype)  # Cast back to original dtype
                return x
            except Exception as e_sdpa_fp32:
                print(f"SDPA also failed with float32: {e_sdpa_fp32}. Raising original error.")
                raise e_sdpa_fp16  # Re-raise the float16 error if float32 also fails
    batch_size = q.shape[0]

    # The view operations might need to happen *before* casting if cu_seqlens are involved
    # Let's keep the original view logic for now
    q_orig_shape = q.shape
    k_orig_shape = k.shape
    v_orig_shape = v.shape

    q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
    k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
    v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])

    if sageattn_varlen is not None:
        # x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        # elif flash_attn_varlen_func is not None:
        #     x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        # Assuming sageattn handles dtypes correctly or needs its own casting
        try:
            x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        except Exception as e_sage:
            print(f"Sage Attn failed: {e_sage}. Falling back.")
            x = None  # Fall through to the xFormers / SDPA paths below
    else:
        # raise NotImplementedError('No Attn Installed!')
        x = None  # Initialize x if sageattn is not used

    # elif flash_attn_varlen_func is not None:  # Commented out
    #     x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)

    if x is None and xformers_attn_func is not None:  # Check if sage failed or wasn't used, and xformers is available
        # Cast to float16 before calling xformers for variable length
        q_fp16 = q.to(torch.float16)
        k_fp16 = k.to(torch.float16)
        v_fp16 = v.to(torch.float16)
        try:
            # xFormers memory_efficient_attention does not directly support cu_seqlens.
            # The original code used flash_attn_varlen_func for this.
            # We might need to pad/unpad manually or use a different xFormers interface if one exists for varlen.
            # For now, trying the standard call, which might be incorrect for packed sequences.
            print("Warning: Calling standard xformers attention for variable length sequence. This might be incorrect. Padding/unpadding or a dedicated varlen function might be needed.")
            x = xformers_attn_func(q_fp16, k_fp16, v_fp16)  # This is likely incorrect for varlen!
            x = x.to(q.dtype)  # Cast back
        except NotImplementedError:
            print("xFormers (standard) failed for varlen with float16, falling back to SDPA.")
            x = None  # Mark xFormers as failed
        except Exception as e_xformers_varlen:
            print(f"xFormers (standard) encountered an error for varlen: {e_xformers_varlen}. Falling back to SDPA.")
            x = None  # Mark xFormers as failed

    if x is None:  # If sage/flash/xformers all failed or were unavailable for varlen
        print("No specialized attention backend worked for varlen, falling back to basic SDPA (potentially incorrect for packed sequences).")
        q_fallback = q.to(torch.float16)  # Use float16 for fallback
        k_fallback = k.to(torch.float16)
        v_fallback = v.to(torch.float16)
        # A single SDPA call would not respect cu_seqlens, so iterate per sample instead.
        # This is slow and inefficient, but more correct than one packed SDPA call.
        try:
            x_list = []
            for i in range(batch_size):
                start_idx_q = cu_seqlens_q[2 * i].item()
                end_idx_q = cu_seqlens_q[2 * i + 1].item()
                start_idx_kv = cu_seqlens_kv[2 * i].item()  # Assuming cu_seqlens_kv matches q structure here
                end_idx_kv = cu_seqlens_kv[2 * i + 1].item()
                # Slice the valid tokens of sample i and move heads in front for SDPA:
                # (seq_len, num_heads, head_dim) -> (1, num_heads, seq_len, head_dim)
                q_i = q_fallback[start_idx_q:end_idx_q].unsqueeze(0).transpose(1, 2)
                k_i = k_fallback[start_idx_kv:end_idx_kv].unsqueeze(0).transpose(1, 2)
                v_i = v_fallback[start_idx_kv:end_idx_kv].unsqueeze(0).transpose(1, 2)
                # (1, num_heads, seq_len_q, head_dim) -> (1, seq_len_q, num_heads, head_dim)
                out_i = torch.nn.functional.scaled_dot_product_attention(q_i, k_i, v_i).transpose(1, 2)
                # Flatten heads back into the hidden dim: (1, seq_len_q, hidden_dim)
                out_i = out_i.reshape(1, -1, q_orig_shape[-2] * q_orig_shape[-1])
                x_list.append(out_i.squeeze(0))

            # Pad shorter sequences to max_seqlen_q before stacking
            padded_x_list = []
            for x_i in x_list:
                pad_len = max_seqlen_q - x_i.shape[0]
                if pad_len > 0:
                    padded_x_i = torch.nn.functional.pad(x_i, (0, 0, 0, pad_len))  # Pads the sequence length dim (dim 0)
                    padded_x_list.append(padded_x_i)
                else:
                    padded_x_list.append(x_i)
            x = torch.stack(padded_x_list, dim=0)  # Shape: (batch_size, max_seqlen_q, hidden_dim)
            x = x.to(q.dtype)  # Cast back
        except Exception as e_sdpa_varlen:
            print(f"Manual SDPA loop for varlen failed: {e_sdpa_varlen}")
            # If even the manual loop fails, we are stuck. Raise error.
            raise NotImplementedError("Could not execute attention with any backend for variable sequence lengths.")

    # At this point x is either (batch_size * max_seqlen_q, num_heads, head_dim) from the
    # packed backends, or (batch_size, max_seqlen_q, hidden_dim) from the SDPA loop; both
    # contain the same number of elements, so reshape to the layout the callers expect.
    x = x.view(batch_size, max_seqlen_q, *q_orig_shape[2:])  # (batch_size, max_seqlen_q, num_heads, head_dim)
    return x

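# The comments in the varlen branch above note that the plain xformers call ignores
# cu_seqlens. A possible alternative (not part of the original file, untested here, and
# assuming the installed xformers version exposes fmha.BlockDiagonalMask) is to express
# the packed layout as a block-diagonal attention bias instead of padding/unpadding.
# Sketch only; nothing in this module calls it:
def xformers_varlen_sketch(q, k, v, cu_seqlens_q, cu_seqlens_kv):
    # q, k, v: (total_tokens, num_heads, head_dim), flattened exactly like above.
    from xformers.ops import fmha
    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()
    kv_lens = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).tolist()
    # Each segment between consecutive cu_seqlens boundaries becomes its own attention
    # block, mirroring how get_cu_seqlens() lays out valid/padded regions.
    attn_bias = fmha.BlockDiagonalMask.from_seqlens(q_lens, kv_lens)
    out = xformers_attn_func(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_bias=attn_bias)
    return out.squeeze(0)  # (total_tokens, num_heads, head_dim)
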
class HunyuanAttnProcessorFlashAttnDouble:
    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        query = apply_rotary_emb_transposed(query, image_rotary_emb)
        key = apply_rotary_emb_transposed(key, image_rotary_emb)

        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        encoder_query = attn.norm_added_q(encoder_query)
        encoder_key = attn.norm_added_k(encoder_key)

        query = torch.cat([query, encoder_query], dim=1)
        key = torch.cat([key, encoder_key], dim=1)
        value = torch.cat([value, encoder_value], dim=1)

        hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        hidden_states = hidden_states.flatten(-2)

        txt_length = encoder_hidden_states.shape[1]
        hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states

class HunyuanAttnProcessorFlashAttnSingle:
    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask

        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        txt_length = encoder_hidden_states.shape[1]

        query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
        key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)

        hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        hidden_states = hidden_states.flatten(-2)

        hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]

        return hidden_states, encoder_hidden_states

class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))

        guidance_proj = self.time_proj(guidance)
        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))

        time_guidance_emb = timesteps_emb + guidance_emb

        pooled_projections = self.text_embedder(pooled_projection)
        conditioning = time_guidance_emb + pooled_projections

        return conditioning

class CombinedTimestepTextProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))

        pooled_projections = self.text_embedder(pooled_projection)

        conditioning = timesteps_emb + pooled_projections

        return conditioning

class HunyuanVideoAdaNorm(nn.Module):
    def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
        super().__init__()

        out_features = out_features or 2 * in_features

        self.linear = nn.Linear(in_features, out_features)
        self.nonlinearity = nn.SiLU()

    def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        temb = self.linear(self.nonlinearity(temb))
        gate_msa, gate_mlp = temb.chunk(2, dim=-1)
        gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
        return gate_msa, gate_mlp

class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
        )

        self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)

        self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            attention_mask=attention_mask,
        )

        gate_msa, gate_mlp = self.norm_out(temb)
        hidden_states = hidden_states + attn_output * gate_msa

        ff_output = self.ff(self.norm2(hidden_states))
        hidden_states = hidden_states + ff_output * gate_mlp

        return hidden_states

class HunyuanVideoIndividualTokenRefiner(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        self.refiner_blocks = nn.ModuleList(
            [
                HunyuanVideoIndividualTokenRefinerBlock(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    attention_bias=attention_bias,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        self_attn_mask = None
        if attention_mask is not None:
            batch_size = attention_mask.shape[0]
            seq_len = attention_mask.shape[1]
            attention_mask = attention_mask.to(hidden_states.device).bool()
            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            self_attn_mask[:, :, :, 0] = True

        for block in self.refiner_blocks:
            hidden_states = block(hidden_states, temb, self_attn_mask)

        return hidden_states

class HunyuanVideoTokenRefiner(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=hidden_size, pooled_projection_dim=in_channels
        )
        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
        self.token_refiner = HunyuanVideoIndividualTokenRefiner(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            num_layers=num_layers,
            mlp_width_ratio=mlp_ratio,
            mlp_drop_rate=mlp_drop_rate,
            attention_bias=attention_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        if attention_mask is None:
            pooled_projections = hidden_states.mean(dim=1)
        else:
            original_dtype = hidden_states.dtype
            mask_float = attention_mask.float().unsqueeze(-1)
            pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
            pooled_projections = pooled_projections.to(original_dtype)

        temb = self.time_text_embed(timestep, pooled_projections)
        hidden_states = self.proj_in(hidden_states)
        hidden_states = self.token_refiner(hidden_states, temb, attention_mask)

        return hidden_states

class HunyuanVideoRotaryPosEmbed(nn.Module):
    def __init__(self, rope_dim, theta):
        super().__init__()
        self.DT, self.DY, self.DX = rope_dim
        self.theta = theta

    @torch.no_grad()
    def get_frequency(self, dim, pos):
        T, H, W = pos.shape
        freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
        freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
        return freqs.cos(), freqs.sin()

    @torch.no_grad()
    def forward_inner(self, frame_indices, height, width, device):
        GT, GY, GX = torch.meshgrid(
            frame_indices.to(device=device, dtype=torch.float32),
            torch.arange(0, height, device=device, dtype=torch.float32),
            torch.arange(0, width, device=device, dtype=torch.float32),
            indexing="ij"
        )

        FCT, FST = self.get_frequency(self.DT, GT)
        FCY, FSY = self.get_frequency(self.DY, GY)
        FCX, FSX = self.get_frequency(self.DX, GX)

        result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)

        return result.to(device)

    @torch.no_grad()
    def forward(self, frame_indices, height, width, device):
        frame_indices = frame_indices.unbind(0)
        results = [self.forward_inner(f, height, width, device) for f in frame_indices]
        results = torch.stack(results, dim=0)
        return results

class AdaLayerNormZero(nn.Module):
    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = emb.unsqueeze(-2)
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
        x = self.norm(x) * (1 + scale_msa) + shift_msa
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

class AdaLayerNormZeroSingle(nn.Module):
    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        emb = emb.unsqueeze(-2)
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
        x = self.norm(x) * (1 + scale_msa) + shift_msa
        return x, gate_msa

class AdaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        emb = emb.unsqueeze(-2)
        emb = self.linear(self.silu(emb))
        scale, shift = emb.chunk(2, dim=-1)
        x = self.norm(x) * (1 + scale) + shift
        return x

class HunyuanVideoSingleTransformerBlock(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 4.0,
        qk_norm: str = "rms_norm",
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim
        mlp_dim = int(hidden_size * mlp_ratio)

        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            bias=True,
            processor=HunyuanAttnProcessorFlashAttnSingle(),
            qk_norm=qk_norm,
            eps=1e-6,
            pre_only=True,
        )

        self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        text_seq_length = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        residual = hidden_states

        # 1. Input normalization
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        norm_hidden_states, norm_encoder_hidden_states = (
            norm_hidden_states[:, :-text_seq_length, :],
            norm_hidden_states[:, -text_seq_length:, :],
        )

        # 2. Attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        attn_output = torch.cat([attn_output, context_attn_output], dim=1)

        # 3. Modulation and residual connection
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = hidden_states + residual

        hidden_states, encoder_hidden_states = (
            hidden_states[:, :-text_seq_length, :],
            hidden_states[:, -text_seq_length:, :],
        )
        return hidden_states, encoder_hidden_states

class HunyuanVideoTransformerBlock(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float,
        qk_norm: str = "rms_norm",
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")

        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            added_kv_proj_dim=hidden_size,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            context_pre_only=False,
            bias=True,
            processor=HunyuanAttnProcessorFlashAttnDouble(),
            qk_norm=qk_norm,
            eps=1e-6,
        )

        self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

        self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Input normalization
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)

        # 2. Joint attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=freqs_cis,
        )

        # 3. Modulation and residual connection
        hidden_states = hidden_states + attn_output * gate_msa
        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa

        norm_hidden_states = self.norm2(hidden_states)
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)

        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp

        # 4. Feed-forward
        ff_output = self.ff(norm_hidden_states)
        context_ff_output = self.ff_context(norm_encoder_hidden_states)

        hidden_states = hidden_states + gate_mlp * ff_output
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output

        return hidden_states, encoder_hidden_states

class ClipVisionProjection(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Linear(in_channels, out_channels * 3)
        self.down = nn.Linear(out_channels * 3, out_channels)

    def forward(self, x):
        projected_x = self.down(nn.functional.silu(self.up(x)))
        return projected_x


class HunyuanVideoPatchEmbed(nn.Module):
    def __init__(self, patch_size, in_chans, embed_dim):
        super().__init__()
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
    def __init__(self, inner_dim):
        super().__init__()
        self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

    @torch.no_grad()
    def initialize_weight_from_another_conv3d(self, another_layer):
        weight = another_layer.weight.detach().clone()
        bias = another_layer.bias.detach().clone()

        sd = {
            'proj.weight': weight.clone(),
            'proj.bias': bias.clone(),
            'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
            'proj_2x.bias': bias.clone(),
            'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
            'proj_4x.bias': bias.clone(),
        }

        sd = {k: v.clone() for k, v in sd.items()}

        self.load_state_dict(sd)
        return

class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 20,
        num_single_layers: int = 40,
        num_refiner_layers: int = 2,
        mlp_ratio: float = 4.0,
        patch_size: int = 2,
        patch_size_t: int = 1,
        qk_norm: str = "rms_norm",
        guidance_embeds: bool = True,
        text_embed_dim: int = 4096,
        pooled_projection_dim: int = 768,
        rope_theta: float = 256.0,
        rope_axes_dim: Tuple[int] = (16, 56, 56),
        has_image_proj=False,
        image_proj_dim=1152,
        has_clean_x_embedder=False,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Latent and condition embedders
        self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
        self.context_embedder = HunyuanVideoTokenRefiner(
            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
        )
        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)

        self.clean_x_embedder = None
        self.image_projection = None

        # 2. RoPE
        self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)

        # 3. Dual stream transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoTransformerBlock(
                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Single stream transformer blocks
        self.single_transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoSingleTransformerBlock(
                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
                )
                for _ in range(num_single_layers)
            ]
        )

        # 5. Output projection
        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)

        self.inner_dim = inner_dim
        self.use_gradient_checkpointing = False
        self.enable_teacache = False

        if has_image_proj:
            self.install_image_projection(image_proj_dim)

        if has_clean_x_embedder:
            self.install_clean_x_embedder()

        self.high_quality_fp32_output_for_inference = False

    def install_image_projection(self, in_channels):
        self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
        self.config['has_image_proj'] = True
        self.config['image_proj_dim'] = in_channels

    def install_clean_x_embedder(self):
        self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
        self.config['has_clean_x_embedder'] = True

    def enable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = True
        print('self.use_gradient_checkpointing = True')

    def disable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = False
        print('self.use_gradient_checkpointing = False')

    def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
        self.enable_teacache = enable_teacache
        self.cnt = 0
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])

    def gradient_checkpointing_method(self, block, *args):
        if self.use_gradient_checkpointing:
            result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
        else:
            result = block(*args)
        return result

    def process_input_hidden_states(
            self,
            latents, latent_indices=None,
            clean_latents=None, clean_latent_indices=None,
            clean_latents_2x=None, clean_latent_2x_indices=None,
            clean_latents_4x=None, clean_latent_4x_indices=None
    ):
        hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
        B, C, T, H, W = hidden_states.shape

        if latent_indices is None:
            latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)

        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
        rope_freqs = rope_freqs.flatten(2).transpose(1, 2)

        if clean_latents is not None and clean_latent_indices is not None:
            clean_latents = clean_latents.to(hidden_states)
            clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
            clean_latents = clean_latents.flatten(2).transpose(1, 2)

            clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
            clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)

        if clean_latents_2x is not None and clean_latent_2x_indices is not None:
            clean_latents_2x = clean_latents_2x.to(hidden_states)
            clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
            clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
            clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)

            clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
            clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
            clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
            clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)

        if clean_latents_4x is not None and clean_latent_4x_indices is not None:
            clean_latents_4x = clean_latents_4x.to(hidden_states)
            clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
            clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
            clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)

            clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
            clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
            clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
            clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)

        return hidden_states, rope_freqs

    def forward(
            self,
            hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
            latent_indices=None,
            clean_latents=None, clean_latent_indices=None,
            clean_latents_2x=None, clean_latent_2x_indices=None,
            clean_latents_4x=None, clean_latent_4x_indices=None,
            image_embeddings=None,
            attention_kwargs=None, return_dict=True
    ):
        if attention_kwargs is None:
            attention_kwargs = {}

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p, p_t = self.config['patch_size'], self.config['patch_size_t']
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p
        original_context_length = post_patch_num_frames * post_patch_height * post_patch_width

        hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)

        temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
        encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)

        if self.image_projection is not None:
            assert image_embeddings is not None, 'You must use image embeddings!'
            extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
            extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)

            # must cat before (not after) encoder_hidden_states, due to attn masking
            encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
            encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)

        with torch.no_grad():
            if batch_size == 1:
                # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
                # If they are not same, then their impls are wrong. Ours are always the correct one.
                text_len = encoder_attention_mask.sum().item()
                encoder_hidden_states = encoder_hidden_states[:, :text_len]
                attention_mask = None, None, None, None
            else:
                img_seq_len = hidden_states.shape[1]
                txt_seq_len = encoder_hidden_states.shape[1]

                cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
                cu_seqlens_kv = cu_seqlens_q
                max_seqlen_q = img_seq_len + txt_seq_len
                max_seqlen_kv = max_seqlen_q

                attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv

        if self.enable_teacache:
            modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]

            if self.cnt == 0 or self.cnt == self.num_steps - 1:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
            else:
                curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
                self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
                should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh

                if should_calc:
                    self.accumulated_rel_l1_distance = 0

            self.previous_modulated_input = modulated_inp
            self.cnt += 1

            if self.cnt == self.num_steps:
                self.cnt = 0

            if not should_calc:
                hidden_states = hidden_states + self.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()

                for block_id, block in enumerate(self.transformer_blocks):
                    hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                        block,
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        attention_mask,
                        rope_freqs
                    )

                for block_id, block in enumerate(self.single_transformer_blocks):
                    hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                        block,
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        attention_mask,
                        rope_freqs
                    )

                self.previous_residual = hidden_states - ori_hidden_states
        else:
            for block_id, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    rope_freqs
                )

            for block_id, block in enumerate(self.single_transformer_blocks):
                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    rope_freqs
                )

        hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)

        hidden_states = hidden_states[:, -original_context_length:, :]

        if self.high_quality_fp32_output_for_inference:
            hidden_states = hidden_states.to(dtype=torch.float32)
            if self.proj_out.weight.dtype != torch.float32:
                self.proj_out.to(dtype=torch.float32)

        hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)

        hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
                                         t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
                                         pt=p_t, ph=p, pw=p)

        if return_dict:
            return Transformer2DModelOutput(sample=hidden_states)

        return hidden_states,

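# --- Usage sketch (not part of the original module) ---
# A minimal, illustrative example of wiring up the model defined above; the default
# config values come from the __init__ signature, and the helper calls are methods
# defined on HunyuanVideoTransformer3DModelPacked. Loading real pretrained weights
# (e.g. via from_pretrained) is out of scope here.
if __name__ == '__main__':
    model = HunyuanVideoTransformer3DModelPacked(
        has_image_proj=True,       # installs ClipVisionProjection via install_image_projection
        has_clean_x_embedder=True  # installs HunyuanVideoPatchEmbedForCleanLatents
    )
    model.enable_gradient_checkpointing()            # trade compute for memory during training
    model.initialize_teacache(enable_teacache=True)  # reuse cached residuals across similar denoising steps
    model.high_quality_fp32_output_for_inference = True
    print(model.inner_dim)  # 24 heads * 128 head dim = 3072 with the default config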