# Imports required by this snippet (module-level in the file that defines the attention class).
import math
from typing import Optional, Tuple

import torch
from torch import nn


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )  # [bsz, nh, t, hd]

    q_position_ids = position_ids
    k_position_ids = position_ids
    kv_seq_len = q_len
    if past_key_value is not None:
        # Instead of reusing cached key/value projections, the cache carries the
        # raw hidden states and their position ids; keys and values are
        # re-projected over the full sequence below.
        # key_states = torch.cat([past_key_value[0], key_states], dim=2)
        # value_states = torch.cat([past_key_value[1], value_states], dim=2)
        k_position_ids = torch.cat([past_key_value[3], k_position_ids], dim=1)
        kv_seq_len += past_key_value[4]
        hidden_states = torch.cat([past_key_value[2], hidden_states], dim=1)

    # Cache layout: (key placeholder, value placeholder, hidden_states,
    # k_position_ids, kv_seq_len); the first two slots are dummy tensors.
    past_key_value = (
        (
            torch.ones(1, 1, 1, 1),
            torch.ones(1),
            hidden_states,
            k_position_ids,
            kv_seq_len,
        )
        if use_cache
        else None
    )

    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, kv_seq_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, kv_seq_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    # Apply rotary embeddings with separate position ids for queries and keys.
    cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, q_position_ids, k_position_ids
    )

    attn_weights = torch.matmul(
        query_states, key_states.transpose(2, 3)
    ) / math.sqrt(self.head_dim)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
        # Clamp fully masked positions to the dtype minimum to avoid overflow.
        attn_weights = torch.max(
            attn_weights,
            torch.tensor(
                torch.finfo(attn_weights.dtype).min, device=attn_weights.device
            ),
        )
    # Upcast attention weights to fp32 for a numerically stable softmax,
    # then cast back to the query dtype.
    attn_weights = nn.functional.softmax(
        attn_weights,
        dim=-1,
        dtype=torch.float32,
    ).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
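

# The forward above calls apply_rotary_pos_emb with separate position ids for
# queries and keys, which differs from the single position_ids argument in the
# stock transformers LLaMA implementation of that period. Below is a minimal
# sketch (an assumption, not part of the paste) of what such a helper could look
# like, assuming cos/sin come from the rotary embedding module with shape
# [1, 1, seq_len, head_dim] as in that implementation.
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate the last dimension by swapping and negating its two halves.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, q_position_ids, k_position_ids):
    # cos, sin: [1, 1, seq_len, head_dim] -> [seq_len, head_dim]
    cos = cos.squeeze(1).squeeze(0)
    sin = sin.squeeze(1).squeeze(0)
    # Gather per-token angles separately for queries and keys, then broadcast
    # over the head dimension: [bsz, 1, len, head_dim].
    q_cos = cos[q_position_ids].unsqueeze(1)
    q_sin = sin[q_position_ids].unsqueeze(1)
    k_cos = cos[k_position_ids].unsqueeze(1)
    k_sin = sin[k_position_ids].unsqueeze(1)
    q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
    k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
    return q_embed, k_embed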
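

# A toy, self-contained illustration (an assumption, not from the paste) of how
# the custom cache tuple grows during incremental decoding: instead of storing
# projected K/V, it carries the raw hidden states and the key position ids, and
# the layer re-projects keys/values over the whole sequence each step.
if __name__ == "__main__":
    bsz, hidden_size = 1, 8

    # Step 1: prefill with 4 tokens, no cache yet.
    prefill_hidden = torch.randn(bsz, 4, hidden_size)
    prefill_pos = torch.arange(4).unsqueeze(0)
    cache = (torch.ones(1, 1, 1, 1), torch.ones(1), prefill_hidden, prefill_pos, 4)

    # Step 2: decode one new token, extending the cached hidden states and positions.
    new_hidden = torch.randn(bsz, 1, hidden_size)
    new_pos = torch.tensor([[4]])
    k_position_ids = torch.cat([cache[3], new_pos], dim=1)
    kv_seq_len = 1 + cache[4]
    full_hidden = torch.cat([cache[2], new_hidden], dim=1)

    print(full_hidden.shape, k_position_ids.shape, kv_seq_len)
    # torch.Size([1, 5, 8]) torch.Size([1, 5]) 5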