Advertisement
Guest User

Untitled

a guest
Sep 17th, 2019
115
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.36 KB | None | 0 0
  1. class SiameseNetwork(nn.Module):
  2.     def __init__(self, context_encoder, query_encoder, context_dim, query_dim):
  3.         super(SiameseNetwork, self).__init__()
  4.         self.context_encoder = context_encoder
  5.         self.query_encoder = query_encoder
  6.        
  7.         # siamese network arch
  8.         self.linear_1 = nn.Linear(context_dim + query_dim, 128)
  9.         self.linear_2 = nn.Linear(128, 1)
  10.         self.relu = nn.LeakyReLU()
  11.        
  12.     def forward(self, context, clens, query_pos, qposlens, query_neg=None, qneglens=None, train=True):
  13.         # take both queries while training and only one while testing to assign a score
  14.         # (second input just ignored if train=False)
  15.         context_repr = self.context_encoder(context, clens)
  16.         query_pos_repr = self.query_encoder(query_pos, qposlens)
  17.         siamese_inp_pos = torch.cat([query_pos_repr, context_repr], dim=-1)
  18.         score_pos = self.linear_2(self.linear_1(siamese_inp_pos))
  19.        
  20.         if train:
  21.             assert query_neg is not None, "you have to provide a second input"
  22.             query_neg_repr = self.query_encoder(query_neg, qneglens)
  23.             siamese_inp_neg = torch.cat([query_neg_repr, context_repr], dim=-1)
  24.             score_neg = self.linear_2(self.linear_1(siamese_inp_neg))
  25.             return score_pos - score_neg
  26.        
  27.         else:
  28.             return score_pos
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement