import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

try:
    import xformers.ops

    MEM_EFFICIENT_ATTN = True
except ImportError:
    MEM_EFFICIENT_ATTN = False

class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    Uses three q, k, v linear layers to compute attention.

    Parameters:
        channels (:obj:`int`): The number of channels in the input and output.
        num_head_channels (:obj:`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels
        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, bias=True)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        # transpose
        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        # get scores: q and k are each scaled by sqrt(sqrt(d)) so the product carries the usual 1/sqrt(d) factor
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)  # TODO: use baddbmm
        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)

        # compute attention output
        hidden_states = torch.matmul(attention_probs, value_states)
        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
        hidden_states = hidden_states.view(new_hidden_states_shape)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states
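
# A minimal usage sketch, not part of the original module: the sizes below are arbitrary
# assumptions, chosen only to show that AttentionBlock keeps the (B, C, H, W) shape of a feature map.
def _attention_block_example():
    block = AttentionBlock(channels=64, num_head_channels=32)
    feature_map = torch.randn(2, 64, 16, 16)
    out = block(feature_map)
    assert out.shape == feature_map.shape  # spatial self-attention preserves the shape
    return out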

class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
    standard transformer action. Finally, reshape to image.

    Parameters:
        in_channels (:obj:`int`): The number of channels in the input and output.
        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
        d_head (:obj:`int`): The number of channels in each head.
        depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
    """

    def __init__(
        self,
        in_channels: int,
        n_heads: int,
        d_head: int,
        depth: int = 1,
        dropout: float = 0.0,
        num_groups: int = 32,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)
            ]
        )

        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def _set_attention_slice(self, slice_size):
        for block in self.transformer_blocks:
            block._set_attention_slice(slice_size)

    def forward(self, hidden_states, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        batch, channel, height, width = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=context)
        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
        hidden_states = self.proj_out(hidden_states)
        return hidden_states + residual
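
# A minimal usage sketch, not part of the original module: all sizes are illustrative assumptions.
# It runs a feature map through SpatialTransformer with a text-like context, as used for
# cross-attention conditioning, and checks that the image shape is preserved.
def _spatial_transformer_example():
    transformer = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=32)
    feature_map = torch.randn(2, 64, 8, 8)
    context = torch.randn(2, 77, 32)  # e.g. a sequence of 77 conditioning vectors
    out = transformer(feature_map, context=context)
    assert out.shape == feature_map.shape
    return out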

class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (:obj:`int`): The number of channels in the input and output.
        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
        d_head (:obj:`int`): The number of channels in each head.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
        context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        dropout=0.0,
        context_dim: Optional[int] = None,
        gated_ff: bool = True,
        checkpoint: bool = True,
    ):
        super().__init__()
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is self-attention if context is None
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def _set_attention_slice(self, slice_size):
        self.attn1._slice_size = slice_size
        self.attn2._slice_size = slice_size

    def forward(self, hidden_states, context=None):
        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
        return hidden_states
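
# A minimal usage sketch, not part of the original module: one transformer block over a token
# sequence with a cross-attention context. Sizes are illustrative assumptions; note that attn2
# falls back to self-attention only when context_dim is left unset.
def _basic_transformer_block_example():
    block = BasicTransformerBlock(dim=64, n_heads=4, d_head=16, context_dim=32)
    tokens = torch.randn(2, 100, 64)
    context = torch.randn(2, 77, 32)
    out = block(tokens, context=context)  # attn1 is self-attention, attn2 attends to the context
    assert out.shape == tokens.shape
    return out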

# $$$ CrossAttention that uses FlashAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py

# constants

EPSILON = 1e-6

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135

from torch.autograd.function import Function

class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

        scale = q.shape[-1] ** -0.5

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, 'b n -> b 1 1 n')
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                # online-softmax update: rescale the running output by the change in running max and
                # normalizer, then accumulate this key/value block's contribution
                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                p = exp_attn_weights / lc

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
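
# A small numerical check, not part of the original module: the tiled forward pass above should
# match ordinary softmax attention up to floating-point error. Bucket sizes of 64 force several
# tiles even for this short sequence; all sizes here are arbitrary assumptions.
def _flash_attention_equivalence_check():
    q, k, v = (torch.randn(1, 2, 128, 16) for _ in range(3))
    tiled = FlashAttentionFunction.apply(q, k, v, None, False, 64, 64)
    scores = torch.einsum('b h i d, b h j d -> b h i j', q, k) * (q.shape[-1] ** -0.5)
    reference = torch.softmax(scores, dim=-1) @ v
    assert torch.allclose(tiled, reference, atol=1e-4)
    return tiled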

# based on https://github.com/gammagec/Dreambooth-SD-optimized/blob/main/ldm/modules/attention.py
class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        # print("$$$ CrossAttention with FlashAttention")
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

        self.flash_func = FlashAttentionFunction
        self.q_bucket_size = 512  # 256 # 128 # smaller buckets should be slower but use less memory
        self.k_bucket_size = 1024  # 512 # 256

    def forward(self, x, context=None, mask=None):
        # with autocast('cuda'):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        context = context.to(x.dtype)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        out = self.flash_func.apply(q, k, v, mask, False, self.q_bucket_size, self.k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
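
# A minimal usage sketch, not part of the original module: drive the FlashAttention-backed
# CrossAttention with a separate context. The query/context sizes are illustrative assumptions.
def _flash_cross_attention_example():
    attn = CrossAttention(query_dim=64, context_dim=32, heads=4, dim_head=16)
    queries = torch.randn(2, 256, 64)
    context = torch.randn(2, 77, 32)
    out = attn(queries, context=context)
    assert out.shape == (2, 256, 64)
    return out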

# $$$ memory-efficient CrossAttention
# copied from https://github.com/gammagec/Dreambooth-SD-optimized with some modifications

from torch import nn, einsum, autocast
from einops import rearrange, repeat

class CrossAttentionDBSDOPT(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        # print("$$$ memory efficient CrossAttention")
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        # with autocast('cuda'):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype, device=q.device)

        # valid values for steps = 2,4,8,16,32,64
        # higher steps is slower but less memory usage
        # at 16 can run 1920x1536 on a 3090, at 64 can run over 1920x1920
        # speed seems to be impacted more on 30x series cards
        steps = 16
        slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
        del r1

        return self.to_out(r2)
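
# A minimal usage sketch, not part of the original module: the sliced loop above computes the same
# result as attending over the full query sequence, just in chunks along the query axis. A length
# of 256 splits evenly into the hard-coded 16 steps; all sizes are illustrative assumptions.
def _sliced_cross_attention_example():
    attn = CrossAttentionDBSDOPT(query_dim=64, context_dim=32, heads=4, dim_head=16)
    queries = torch.randn(2, 256, 64)
    context = torch.randn(2, 77, 32)
    out = attn(queries, context=context)
    assert out.shape == (2, 256, 64)
    return out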

class CrossAttentionOrg(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (:obj:`int`): The number of channels in the query.
        context_dim (:obj:`int`, *optional*):
            The number of channels in the context. If not given, defaults to `query_dim`.
        heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
    """

    def __init__(
        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim if context_dim is not None else query_dim

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self._slice_size = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def forward(self, hidden_states, context=None, mask=None):
        batch_size, sequence_length, _ = hidden_states.shape

        query = self.to_q(hidden_states)
        context = context if context is not None else hidden_states
        key = self.to_k(context)
        value = self.to_v(context)

        dim = query.shape[-1]

        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        # TODO(PVP) - mask is currently never used. Remember to re-implement when used

        # attention, what we cannot get enough of
        if MEM_EFFICIENT_ATTN:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
        elif self._slice_size is None or query.shape[0] // self._slice_size == 1:
            hidden_states = self._attention(query, key, value)
        else:
            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return self.to_out(hidden_states)

    def _attention(self, query, key, value):
        # TODO: use baddbmm for better performance
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        attention_probs = attention_scores.softmax(dim=-1)
        # compute attention output
        hidden_states = torch.matmul(attention_probs, value)
        return hidden_states

    def _sliced_attention(self, query, key, value, sequence_length, dim):
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = (
                torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
            )  # TODO: use baddbmm for better performance
            attn_slice = attn_slice.softmax(dim=-1)
            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        return hidden_states
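
# A small numerical check, not part of the original module: the sliced attention path should give
# the same result as the dense path, only trading speed for peak memory. It calls the internal
# helpers directly so it runs regardless of whether xformers is installed; sizes are assumptions.
def _sliced_attention_equivalence_check():
    attn = CrossAttentionOrg(query_dim=64, heads=4, dim_head=16)
    hidden = torch.randn(2, 64, 64)
    query = attn.reshape_heads_to_batch_dim(attn.to_q(hidden))
    key = attn.reshape_heads_to_batch_dim(attn.to_k(hidden))
    value = attn.reshape_heads_to_batch_dim(attn.to_v(hidden))
    dense = attn._attention(query, key, value)
    attn._slice_size = 2  # process 2 of the batch * heads = 8 rows at a time
    sliced = attn._sliced_attention(query, key, value, sequence_length=64, dim=64)
    assert torch.allclose(dense, sliced, atol=1e-5)
    return dense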

class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (:obj:`int`): The number of channels in the input.
        dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. Note that this
            implementation always uses a GEGLU projection regardless of the flag.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
    """

    def __init__(
        self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        project_in = GEGLU(dim, inner_dim)

        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states):
        return self.net(hidden_states)

# feedforward
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (:obj:`int`): The number of channels in the input.
        dim_out (:obj:`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * F.gelu(gate)
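
# A minimal usage sketch, not part of the original module: GEGLU projects to twice the hidden
# width and gates one half with GELU of the other; FeedForward wraps it between a dropout and an
# output projection. The sizes below are illustrative assumptions.
def _feed_forward_example():
    geglu = GEGLU(dim_in=64, dim_out=256)
    tokens = torch.randn(2, 100, 64)
    assert geglu(tokens).shape == (2, 100, 256)

    ff = FeedForward(dim=64, mult=4)  # hidden width = 4 * 64 = 256
    out = ff(tokens)
    assert out.shape == tokens.shape
    return out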