import torch
import torch.nn as nn
import torch.nn.functional as F

# Hidden size of the backbone encoder. Not defined in the original paste;
# 768 assumes a BERT-base style model.
DIM = 768


class DocModel(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        max_doc_length: int,
        max_position: int,
    ) -> None:
        super().__init__()
        self.model = model
        self.norm = nn.LayerNorm(DIM)
        self.projection = nn.Linear(DIM, DIM, bias=False)
        self.max_doc_length = max_doc_length
        self.max_position = max_position
        # Freeze the backbone; only the projection head is trainable.
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        batch_size = len(input_ids)
        # Flatten (batch, positions) into one encoder batch of chunks
        # and keep the CLS token embedding of each chunk.
        out = self.model(
            input_ids=input_ids.view(-1, self.max_doc_length),
            attention_mask=attention_mask.view(-1, self.max_doc_length),
        )[0][:, 0, :]  # get CLS token
        out = self.projection(self.norm(F.relu(out)))
        # Restore the per-position layout: (batch, max_position, DIM).
        out = out.view(batch_size, self.max_position, -1)
        return F.normalize(out, dim=-1)


class QueryModel(nn.Module):
    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        self.norm = nn.LayerNorm(DIM)
        self.projection = nn.Linear(DIM, DIM, bias=False)
        # Freeze the backbone; only the projection head is trainable.
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        out = self.model(input_ids, attention_mask)[0][:, 0, :]  # get CLS token
        out = self.projection(self.norm(F.relu(out)))
        return F.normalize(out, dim=-1)


class PositionModel(nn.Module):
    def __init__(self, max_position: int) -> None:
        super().__init__()
        pos = torch.rand(max_position) / max_position
        self.register_parameter("bias", nn.Parameter(pos))

    def forward(self):
        # Learned, monotonically decreasing, non-positive bias: position 0
        # gets 0, later positions get increasingly negative values.
        cum_pos = F.softplus(self.bias).cumsum(dim=-1).flip(dims=[-1])
        max_bias = cum_pos[0]
        return cum_pos - max_bias