import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, einsum
# rearrange is used by the FlashAttention / memory-efficient attention classes below
from einops import rearrange

try:
    import xformers.ops
    MEM_EFFICIENT_ATTN = True
except ImportError:
    MEM_EFFICIENT_ATTN = False

class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    Uses three q, k, v linear layers to compute attention.

    Parameters:
        channels (:obj:`int`): The number of channels in the input and output.
        num_head_channels (:obj:`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, 1)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        # transpose
        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        # get scores
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)  # TODO: use baddbmm
        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)

        # compute attention output
        hidden_states = torch.matmul(attention_probs, value_states)

        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
        hidden_states = hidden_states.view(new_hidden_states_shape)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

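# Illustrative usage sketch for AttentionBlock. The sizes below are hypothetical;
# channels must be divisible by num_groups (32 by default) and by num_head_channels.
def _demo_attention_block():
    block = AttentionBlock(channels=64, num_head_channels=32)
    x = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
    out = block(x)
    assert out.shape == x.shape  # spatial self-attention preserves the input shape
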
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
    standard transformer action. Finally, reshape to image.

    Parameters:
        in_channels (:obj:`int`): The number of channels in the input and output.
        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
        d_head (:obj:`int`): The number of channels in each head.
        depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
    """

    def __init__(
        self,
        in_channels: int,
        n_heads: int,
        d_head: int,
        depth: int = 1,
        dropout: float = 0.0,
        num_groups: int = 32,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)
            ]
        )

        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def _set_attention_slice(self, slice_size):
        for block in self.transformer_blocks:
            block._set_attention_slice(slice_size)

    def forward(self, hidden_states, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        batch, channel, height, width = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=context)
        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
        hidden_states = self.proj_out(hidden_states)
        return hidden_states + residual

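# Illustrative usage sketch for SpatialTransformer with a CLIP-like context. The
# sizes are hypothetical; in_channels must be divisible by num_groups (32 by default).
def _demo_spatial_transformer():
    st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, context_dim=768)
    x = torch.randn(1, 320, 32, 32)    # (batch, channels, height, width)
    context = torch.randn(1, 77, 768)  # (batch, tokens, context_dim)
    out = st(x, context=context)
    assert out.shape == x.shape        # output is reshaped back to image form
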
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (:obj:`int`): The number of channels in the input and output.
        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
        d_head (:obj:`int`): The number of channels in each head.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
        context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        dropout=0.0,
        context_dim: Optional[int] = None,
        gated_ff: bool = True,
        checkpoint: bool = True,
    ):
        super().__init__()
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def _set_attention_slice(self, slice_size):
        # _slice_size is only read by the sliced-attention path of CrossAttentionOrg;
        # the FlashAttention-based CrossAttention below simply ignores it
        self.attn1._slice_size = slice_size
        self.attn2._slice_size = slice_size

    def forward(self, hidden_states, context=None):
        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
        return hidden_states

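# Illustrative usage sketch for BasicTransformerBlock on token sequences of shape
# (batch, tokens, dim); without a context, attn2 also falls back to self-attention.
# Sizes are hypothetical.
def _demo_basic_transformer_block():
    blk = BasicTransformerBlock(dim=64, n_heads=4, d_head=16, context_dim=64)
    tokens = torch.randn(2, 128, 64)
    out = blk(tokens, context=torch.randn(2, 10, 64))
    assert out.shape == tokens.shape
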
# $$$ CrossAttention that uses FlashAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py

# constants

EPSILON = 1e-6

# helper functions

def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135

from torch.autograd.function import Function


class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

        scale = q.shape[-1] ** -0.5

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, 'b n -> b 1 1 n')
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdim=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(dim=-1, keepdim=True).clamp(min=EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                p = exp_attn_weights / lc

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                D = (doc * oc).sum(dim=-1, keepdim=True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None

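# Illustrative usage sketch: calling FlashAttentionFunction directly on head-major
# tensors of shape (batch, heads, seq_len, dim_head), with mask=None, causal=False.
# The bucket sizes (32, 32) are hypothetical and chosen only for this small example.
def _demo_flash_attention_function():
    q = torch.randn(1, 2, 64, 16)
    k = torch.randn(1, 2, 64, 16)
    v = torch.randn(1, 2, 64, 16)
    out = FlashAttentionFunction.apply(q, k, v, None, False, 32, 32)
    assert out.shape == q.shape
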
# based on https://github.com/gammagec/Dreambooth-SD-optimized/blob/main/ldm/modules/attention.py
class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        # print("$$$ CrossAttention with FlashAttention")
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

        self.flash_func = FlashAttentionFunction
        self.q_bucket_size = 512  # 256 # # 128 # smaller buckets should be slower but use less memory
        self.k_bucket_size = 1024  # 512 # # 256

    def forward(self, x, context=None, mask=None):
        # with autocast('cuda'):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        context = context.to(x.dtype)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        out = self.flash_func.apply(q, k, v, mask, False, self.q_bucket_size, self.k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)

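# Illustrative usage sketch: the FlashAttention-backed CrossAttention consumes
# (batch, tokens, query_dim) queries and an optional (batch, tokens, context_dim)
# context. The sizes below are hypothetical Stable-Diffusion-like values.
def _demo_flash_cross_attention():
    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
    x = torch.randn(1, 256, 320)
    context = torch.randn(1, 77, 768)
    out = attn(x, context=context)
    assert out.shape == (1, 256, 320)
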
# $$$ memory-efficient CrossAttention
# copied from https://github.com/gammagec/Dreambooth-SD-optimized, with minor changes
from torch import nn, einsum, autocast
from einops import rearrange, repeat


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


class CrossAttentionDBSDOPT(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        # print("$$$ memory efficient CrossAttention")
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        # with autocast('cuda'):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype, device=q.device)

        # valid values for steps = 2,4,8,16,32,64
        # higher steps is slower but less memory usage
        # at 16 can run 1920x1536 on a 3090, at 64 can run over 1920x1920
        # speed seems to be impacted more on 30x series cards
        steps = 16
        slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
            s1 *= self.scale
            s2 = s1.softmax(dim=-1)
            del s1
            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2
        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
        del r1

        return self.to_out(r2)

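# Illustrative usage sketch: the sliced variant has the same call signature as the
# other CrossAttention classes; only the peak memory of the attention matrix
# differs. Sizes are hypothetical.
def _demo_sliced_cross_attention():
    attn = CrossAttentionDBSDOPT(query_dim=320, context_dim=768, heads=8, dim_head=40)
    x = torch.randn(1, 256, 320)
    context = torch.randn(1, 77, 768)
    out = attn(x, context=context)
    assert out.shape == (1, 256, 320)
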
class CrossAttentionOrg(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (:obj:`int`): The number of channels in the query.
        context_dim (:obj:`int`, *optional*):
            The number of channels in the context. If not given, defaults to `query_dim`.
        heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
    """

    def __init__(
        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim if context_dim is not None else query_dim

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self._slice_size = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def forward(self, hidden_states, context=None, mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        query = self.to_q(hidden_states)
        context = context if context is not None else hidden_states
        key = self.to_k(context)
        value = self.to_v(context)

        dim = query.shape[-1]

        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        # TODO(PVP) - mask is currently never used. Remember to re-implement when used
        # attention, what we cannot get enough of
        if MEM_EFFICIENT_ATTN:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
        elif self._slice_size is None or query.shape[0] // self._slice_size == 1:
            hidden_states = self._attention(query, key, value)
        else:
            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return self.to_out(hidden_states)

    def _attention(self, query, key, value):
        # TODO: use baddbmm for better performance
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        attention_probs = attention_scores.softmax(dim=-1)
        # compute attention output
        hidden_states = torch.matmul(attention_probs, value)
        return hidden_states

    def _sliced_attention(self, query, key, value, sequence_length, dim):
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = (
                torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
            )  # TODO: use baddbmm for better performance
            attn_slice = attn_slice.softmax(dim=-1)
            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        return hidden_states

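# Illustrative usage sketch: the original diffusers-style layer picks its path at
# runtime. If xformers is importable (MEM_EFFICIENT_ATTN), tensors generally need
# to be on CUDA; otherwise _slice_size selects between full and sliced attention.
# Sizes are hypothetical.
def _demo_original_cross_attention():
    attn = CrossAttentionOrg(query_dim=320, context_dim=768, heads=8, dim_head=40)
    attn._slice_size = 4  # split the batch*heads axis to save memory
    x = torch.randn(1, 256, 320)
    context = torch.randn(1, 77, 768)
    out = attn(x, context=context)
    assert out.shape == (1, 256, 320)
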
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (:obj:`int`): The number of channels in the input.
        dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
    """

    def __init__(
        self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        # note: `glu` is accepted for API compatibility but not used here; the input
        # projection is always the gated GEGLU variant defined below
        project_in = GEGLU(dim, inner_dim)

        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states):
        return self.net(hidden_states)

# feedforward
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (:obj:`int`): The number of channels in the input.
        dim_out (:obj:`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * F.gelu(gate)

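# Illustrative usage sketch: GEGLU projects to 2 * dim_out and gates one half with
# GELU of the other, so the FeedForward built on it maps dim -> mult * dim -> dim_out.
# Sizes are hypothetical.
def _demo_feed_forward_geglu():
    geglu = GEGLU(dim_in=64, dim_out=128)
    assert geglu(torch.randn(2, 10, 64)).shape == (2, 10, 128)

    ff = FeedForward(dim=64, mult=4)
    assert ff(torch.randn(2, 10, 64)).shape == (2, 10, 64)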