import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

# Modified LLaMA-style attention forward. Instead of caching the projected keys
# and values, the cache carries the raw hidden states, the key position ids, and
# the running key/value length; k_proj and v_proj are re-applied to the full
# accumulated sequence on every step. The first two cache slots hold dummy
# tensors so the tuple keeps the (key, value, ...) layout callers may expect.
# Note: apply_rotary_pos_emb is assumed here to be a variant that takes separate
# query and key position ids, not the stock transformers helper.


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    # Project queries for the new tokens only: [bsz, nh, q_len, hd]
    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    q_position_ids = position_ids
    k_position_ids = position_ids
    kv_seq_len = q_len
    if past_key_value is not None:
        # Unlike the stock implementation (which would concatenate cached
        # key/value projections, as in the commented-out lines below), this
        # version concatenates the cached hidden states and position ids.
        # key_states = torch.cat([past_key_value[0], key_states], dim=2)
        # value_states = torch.cat([past_key_value[1], value_states], dim=2)
        k_position_ids = torch.cat([past_key_value[3], k_position_ids], dim=1)
        kv_seq_len += past_key_value[4]
        hidden_states = torch.cat([past_key_value[2], hidden_states], dim=1)

    # Cache layout: (dummy_k, dummy_v, hidden_states, k_position_ids, kv_seq_len)
    past_key_value = (
        (
            torch.ones(1, 1, 1, 1),
            torch.ones(1),
            hidden_states,
            k_position_ids,
            kv_seq_len,
        )
        if use_cache
        else None
    )

    # Re-project keys and values over the full accumulated sequence.
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, kv_seq_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, kv_seq_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    # Rotary embeddings, applied with separate position ids for queries and keys.
    cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, q_position_ids, k_position_ids
    )

    # Scaled dot-product attention scores: [bsz, nh, q_len, kv_seq_len]
    attn_weights = torch.matmul(
        query_states, key_states.transpose(2, 3)
    ) / math.sqrt(self.head_dim)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.max(
            attn_weights,
            torch.tensor(
                torch.finfo(attn_weights.dtype).min, device=attn_weights.device
            ),
        )

    # Upcast attention to fp32 for the softmax, then cast back.
    attn_weights = nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(query_states.dtype)

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
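
Below is a minimal, self-contained sketch of just the cache bookkeeping from the forward above, assuming the same past_key_value layout (dummy_k, dummy_v, accumulated hidden_states, k_position_ids, kv_seq_len). The step helper, tensor sizes, and the prefill/decode loop are illustrative assumptions and not part of the original paste; the sketch only shows how the cached hidden states and key position ids grow across decoding steps, which is what lets keys and values (and their rotary embeddings) be recomputed from scratch on every step.

# Hypothetical cache-bookkeeping sketch; sizes and helper name are assumptions.
import torch

bsz, hidden_size = 1, 16

def step(hidden_states, position_ids, past_key_value, use_cache=True):
    # Mirrors the caching logic of the forward above, without the attention math.
    k_position_ids = position_ids
    kv_seq_len = hidden_states.size(1)
    if past_key_value is not None:
        k_position_ids = torch.cat([past_key_value[3], k_position_ids], dim=1)
        kv_seq_len += past_key_value[4]
        hidden_states = torch.cat([past_key_value[2], hidden_states], dim=1)
    return (
        (torch.ones(1, 1, 1, 1), torch.ones(1), hidden_states, k_position_ids, kv_seq_len)
        if use_cache
        else None
    )

# Prefill with a 4-token prompt, then decode two tokens one at a time.
cache = step(torch.randn(bsz, 4, hidden_size), torch.arange(4)[None], None)
for pos in (4, 5):
    cache = step(torch.randn(bsz, 1, hidden_size), torch.tensor([[pos]]), cache)

print(cache[2].shape)  # torch.Size([1, 6, 16]) -- all hidden states seen so far
print(cache[3])        # tensor([[0, 1, 2, 3, 4, 5]]) -- key position ids
print(cache[4])        # 6 -- running key/value sequence length

The trade-off implied by this layout is extra compute (k_proj/v_proj and rotary are redone over the whole cached sequence each step) in exchange for the freedom to assign past tokens new position ids at every step, rather than freezing their rotary-embedded keys in the cache.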