Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class SiameseNetwork(nn.Module):
    """Score a query against a context with a small two-layer head.

    The context and the query are encoded by the caller-supplied encoders.
    The two representations are concatenated as ``[query_repr, context_repr]``
    along the last dimension and passed through
    ``Linear -> LeakyReLU -> Linear`` to yield one score per example (the last
    output dimension is 1).

    With ``train=True`` a positive and a negative query are both scored
    against the same context, and ``score_pos - score_neg`` is returned. This
    is presumably consumed by a pairwise/margin loss defined by the caller.
    With ``train=False`` only the positive query is scored.

    Args:
        context_encoder: Callable ``(context, clens) -> Tensor`` that produces
            the context representation.
        query_encoder: Callable ``(query, qlens) -> Tensor`` that produces the
            query representation. It is shared by positive and negative queries.
        context_dim: Last-dimension size of the context representation.
        query_dim: Last-dimension size of the query representation.
            ``context_dim + query_dim`` must equal the last dimension of the
            concatenated representations.
        hidden_dim: Width of the hidden layer of the scoring head. Defaults to
            128, the value previously hard-coded.
    """

    def __init__(self, context_encoder, query_encoder, context_dim, query_dim,
                 hidden_dim=128):
        super().__init__()
        self.context_encoder = context_encoder
        self.query_encoder = query_encoder
        # Scoring head applied to the concatenation [query_repr, context_repr].
        self.linear_1 = nn.Linear(context_dim + query_dim, hidden_dim)
        self.linear_2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.LeakyReLU()

    def _score(self, query_repr, context_repr):
        """Return the head's score (last dim 1) for one encoded query/context pair."""
        siamese_inp = torch.cat([query_repr, context_repr], dim=-1)
        # The activation is required: without it linear_2(linear_1(x)) is a
        # single affine map and the hidden layer adds no capacity.
        return self.linear_2(self.relu(self.linear_1(siamese_inp)))

    def forward(self, context, clens, query_pos, qposlens, query_neg=None,
                qneglens=None, train=True):
        """Score ``query_pos`` (and, when training, ``query_neg``) against ``context``.

        Args:
            context: Input passed to ``context_encoder`` together with ``clens``.
            clens: Length information forwarded to ``context_encoder``.
            query_pos: Positive query, passed to ``query_encoder`` with ``qposlens``.
            qposlens: Length information for ``query_pos``.
            query_neg: Negative query. Required when ``train`` is true and
                ignored otherwise.
            qneglens: Length information for ``query_neg``.
            train: Selects pairwise mode. This flag is independent of
                ``nn.Module.train()`` / ``self.training``, so callers must pass
                it explicitly.

        Returns:
            ``score_pos - score_neg`` when ``train`` is true, otherwise
            ``score_pos``. Both scores come from a ``Linear(hidden_dim, 1)``
            output, so the last dimension is 1.

        Raises:
            ValueError: If ``train`` is true and ``query_neg`` is None.
        """
        # Validate before doing any encoder work (and unlike `assert`, this
        # check survives `python -O`).
        if train and query_neg is None:
            raise ValueError("you have to provide a second input")

        context_repr = self.context_encoder(context, clens)
        query_pos_repr = self.query_encoder(query_pos, qposlens)
        score_pos = self._score(query_pos_repr, context_repr)
        if not train:
            return score_pos

        query_neg_repr = self.query_encoder(query_neg, qneglens)
        score_neg = self._score(query_neg_repr, context_repr)
        return score_pos - score_neg
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement