import torch
import torch.nn as nn
import torch.nn.functional as F

# Hidden size of the backbone encoder. The original paste does not define it;
# 768 (BERT-base) is assumed here.
DIM = 768


class DocModel(nn.Module):
    """Encodes each document position with a frozen backbone encoder."""

    def __init__(
        self,
        model: nn.Module,
        max_doc_length: int,
        max_position: int,
    ) -> None:
        super().__init__()
        self.model = model
        self.norm = nn.LayerNorm(DIM)
        self.projection = nn.Linear(DIM, DIM, bias=False)
        self.max_doc_length = max_doc_length
        self.max_position = max_position

        # Freeze the backbone; only the projection head is trainable.
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        batch_size = len(input_ids)
        # Flatten (batch, max_position, max_doc_length) into one encoder batch
        # and keep the [CLS] token embedding of each document.
        out = self.model(
            input_ids=input_ids.view(-1, self.max_doc_length),
            attention_mask=attention_mask.view(-1, self.max_doc_length),
        )[0][:, 0, :]  # [0] = last hidden state, [:, 0, :] = CLS token
        out = self.projection(self.norm(F.relu(out)))
        # Restore the per-position layout and L2-normalize each embedding.
        out = out.view(batch_size, self.max_position, -1)
        return F.normalize(out, dim=-1)


class QueryModel(nn.Module):
    """Encodes a query into a single normalized embedding."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        self.norm = nn.LayerNorm(DIM)
        self.projection = nn.Linear(DIM, DIM, bias=False)

        # Freeze the backbone; only the projection head is trainable.
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        out = self.model(input_ids, attention_mask)[0][:, 0, :]  # CLS token
        out = self.projection(self.norm(F.relu(out)))
        return F.normalize(out, dim=-1)


class PositionModel(nn.Module):
    """Learns a monotonically non-increasing bias over document positions."""

    def __init__(self, max_position: int) -> None:
        super().__init__()
        pos = torch.rand(max_position) / max_position
        self.register_parameter("bias", nn.Parameter(pos))

    def forward(self) -> torch.FloatTensor:
        # softplus keeps the increments positive, cumsum + flip produces a
        # decreasing sequence, and subtracting the first entry anchors the
        # bias at zero for position 0 (later positions get a negative bias).
        cum_pos = F.softplus(self.bias).cumsum(dim=-1).flip(dims=[-1])
        max_bias = cum_pos[0]
        return cum_pos - max_bias
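

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original paste). It assumes a Hugging
# Face encoder ("bert-base-uncased") as the frozen backbone, 4 document
# positions of 128 tokens each, and dot-product scoring plus the learned
# position bias; the model name, shapes, and scoring rule are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    MAX_POSITION, MAX_DOC_LENGTH = 4, 128

    backbone = AutoModel.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    doc_model = DocModel(backbone, max_doc_length=MAX_DOC_LENGTH, max_position=MAX_POSITION)
    query_model = QueryModel(backbone)
    position_model = PositionModel(max_position=MAX_POSITION)

    query = tokenizer("example query", return_tensors="pt")
    passages = ["passage one", "passage two", "passage three", "passage four"]
    docs = tokenizer(
        passages,
        padding="max_length",
        truncation=True,
        max_length=MAX_DOC_LENGTH,
        return_tensors="pt",
    )

    # Query embedding: (1, DIM). Document embeddings: (1, MAX_POSITION, DIM).
    q_emb = query_model(query["input_ids"], query["attention_mask"])
    d_emb = doc_model(
        docs["input_ids"].view(1, -1),
        docs["attention_mask"].view(1, -1),
    )

    # Cosine similarity per position plus the learned position bias: (1, MAX_POSITION).
    scores = torch.einsum("bd,bpd->bp", q_emb, d_emb) + position_model()
    print(scores)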