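What follows is a PyTorch sketch of the Match-LSTM reader for span-extraction question answering, apparently after Wang & Jiang, "Machine Comprehension Using Match-LSTM and Answer Pointer" (2016). For each context position i, the model attends over the query with G_i = tanh(Wq*Hq + (Wp*hp_i + Wr*hr_{i-1})), turns the scores w^T G_i into weights with a softmax, concatenates the attended query summary with the i-th context encoding, and feeds the pair through an LSTM cell; the sequence of hidden states H_r then goes to a pointer network that predicts the start and end indices of the answer span. WordEmbedding, PointerNetwork, and to_var are assumed to be helpers defined elsewhere in the same project.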
import torch
import torch.nn as nn
import torch.nn.functional as F

# WordEmbedding, PointerNetwork, and to_var are assumed to be defined
# elsewhere in the project.


class MatchLSTM(nn.Module):
    def __init__(self, args):
        super(MatchLSTM, self).__init__()
        self.embd_size = args.embd_size
        d = self.embd_size
        self.answer_token_len = args.answer_token_len

        self.embd = WordEmbedding(args)
        # batch_first=True so the GRUs accept the (N, T, d) batches produced
        # below; note that dropout has no effect on a single-layer RNN
        self.ctx_rnn = nn.GRU(d, d, batch_first=True, dropout=0.2)
        self.query_rnn = nn.GRU(d, d, batch_first=True, dropout=0.2)

        self.ptr_net = PointerNetwork(d, d, self.answer_token_len)  # TBD

        self.w  = nn.Parameter(torch.rand(1, d, 1), requires_grad=True)  # (1, d, 1)
        self.Wq = nn.Parameter(torch.rand(1, d, d), requires_grad=True)  # (1, d, d)
        self.Wp = nn.Parameter(torch.rand(1, d, d), requires_grad=True)  # (1, d, d)
        self.Wr = nn.Parameter(torch.rand(1, d, d), requires_grad=True)  # (1, d, d)

        self.match_lstm_cell = nn.LSTMCell(2 * d, d)

    def forward(self, context, query):
        # params
        d = self.embd_size
        bs = context.size(0)  # batch size
        T = context.size(1)   # context length
        J = query.size(1)     # query length

        # GRU preprocessing layer: contextualize both sequences
        embd_context = self.embd(context)              # (N, T, d)
        embd_context, _h = self.ctx_rnn(embd_context)  # (N, T, d)
        embd_query = self.embd(query)                  # (N, J, d)
        embd_query, _h = self.query_rnn(embd_query)    # (N, J, d)

        # Match-LSTM layer
        G = to_var(torch.zeros(bs, T, J, d))  # (N, T, J, d)

        wh_q = torch.bmm(embd_query, self.Wq.expand(bs, d, d))  # (N, J, d) = (N, J, d)(N, d, d)

        hidden = to_var(torch.randn([bs, d]))      # (N, d)
        cell_state = to_var(torch.randn([bs, d]))  # (N, d)
        # TODO bidirectional
        H_r = []
        for i in range(T):
            # attention between the i-th context word and every query word
            wh_p_i = torch.bmm(embd_context[:, i, :].clone().unsqueeze(1),
                               self.Wp.expand(bs, d, d)).squeeze(1)  # (N, 1, d) -> (N, d)
            wh_r_i = torch.bmm(hidden.unsqueeze(1),
                               self.Wr.expand(bs, d, d)).squeeze(1)  # (N, 1, d) -> (N, d)
            sec_elm = (wh_p_i + wh_r_i).unsqueeze(1).expand(bs, J, d)  # (N, J, d)

            G[:, i, :, :] = torch.tanh((wh_q + sec_elm).view(-1, d)).view(bs, J, d)  # (N, J, d) # TODO bias

            # softmax-normalized attention weights over the query words
            attn_i = F.softmax(torch.bmm(G[:, i, :, :].clone(), self.w.expand(bs, d, 1)).squeeze(2), dim=1)  # (N, J)
            attn_query = torch.bmm(attn_i.unsqueeze(1), embd_query).squeeze(1)  # (N, d)
            z = torch.cat((embd_context[:, i, :], attn_query), 1)  # (N, 2d)

            hidden, cell_state = self.match_lstm_cell(z, (hidden, cell_state))  # (N, d), (N, d)
            H_r.append(hidden)
        H_r = torch.stack(H_r, dim=1)  # (N, T, d)

        indices = self.ptr_net(H_r)  # (N, M, T), where M means (start, end)
        return indices
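
Below is a minimal smoke test, assuming hypothetical stand-ins for the pieces the paste does not define: WordEmbedding is stubbed as a bare nn.Embedding lookup, to_var as the identity (on PyTorch 0.4+ plain tensors track gradients, so no Variable wrapping is needed), and PointerNetwork as a single linear scorer over context positions rather than a real pointer decoder. None of these stubs come from the original project; they exist only to exercise the forward pass end to end.

# --- hypothetical stubs, for a quick smoke test only ---
from types import SimpleNamespace

import torch
import torch.nn as nn


def to_var(x):
    # identity stub: modern PyTorch tensors need no Variable wrapping
    return x


class WordEmbedding(nn.Module):
    # stub: a bare trainable embedding lookup
    def __init__(self, args):
        super(WordEmbedding, self).__init__()
        self.emb = nn.Embedding(args.vocab_size, args.embd_size)

    def forward(self, x):
        return self.emb(x)  # (N, L) -> (N, L, d)


class PointerNetwork(nn.Module):
    # stub: one linear score per answer token over every context position
    def __init__(self, in_dim, hidden_dim, answer_token_len):
        super(PointerNetwork, self).__init__()
        self.score = nn.Linear(in_dim, answer_token_len)

    def forward(self, H_r):                     # (N, T, d)
        return self.score(H_r).transpose(1, 2)  # (N, M, T)


args = SimpleNamespace(embd_size=64, vocab_size=1000, answer_token_len=2)
model = MatchLSTM(args)
context = torch.randint(0, args.vocab_size, (2, 30))  # (N, T) token ids
query = torch.randint(0, args.vocab_size, (2, 7))     # (N, J) token ids
out = model(context, query)
print(out.size())  # (2, 2, 30) = (N, M, T)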