import random
import torch
import torch.nn as nn
import torch.nn.functional as F


class Seq2SeqAttention(nn.Module):
    def __init__(self, vecs_enc, idx2word_enc, em_sz_enc, vecs_dec, idx2word_dec, em_sz_dec,
                 num_hidden, out_seq_length, num_layers=2, activation=torch.tanh, pad_idx=1):
        super().__init__()
        self.num_hidden = num_hidden
        self.out_seq_length = out_seq_length
        self.num_layers = num_layers
        self.activation = activation
        # encoder: bidirectional GRU over pretrained source embeddings
        self.encoder_embeddings = create_embeddings(vecs_enc, idx2word_enc, em_sz_enc, pad_idx)
        self.encoder_dropout_emb = nn.Dropout(0.1)
        self.encoder_dropout = nn.Dropout(0.1)
        self.encoder_gru = nn.GRU(em_sz_enc, num_hidden, num_layers=num_layers, bidirectional=True)
        # project the concatenated directions of the encoder state to the decoder size
        self.encoder_out = nn.Linear(num_hidden * 2, em_sz_dec, bias=False)
        # decoder: unidirectional GRU initialised from the projected encoder state
        self.decoder_embeddings = create_embeddings(vecs_dec, idx2word_dec, em_sz_dec, pad_idx)
        self.decoder_dropout = nn.Dropout(0.1)
        self.decoder_gru = nn.GRU(em_sz_dec, em_sz_dec, num_layers=num_layers)
        # the decoder GRU outputs em_sz_dec features, so the vocabulary projection must take
        # em_sz_dec inputs (this also makes the weight tying below shape-consistent)
        self.out = nn.Linear(em_sz_dec, len(idx2word_dec))
        self.out.weight.data = self.decoder_embeddings.weight.data  # tie output weights to decoder embeddings
        # additive attention parameters: score = V . tanh(enc_out @ W1 + l2(hidden))
        self.W1 = rand_p(num_hidden * 2, em_sz_dec)
        self.l2 = nn.Linear(em_sz_dec, em_sz_dec)
        self.l3 = nn.Linear(em_sz_dec + num_hidden * 2, em_sz_dec)
        self.V = rand_p(em_sz_dec)

    def forward(self, X, y=None, tf_ratio=0.0, return_attention=False):
        # encode forward
        seq_len, batch_size = X.size()
        hidden = self.initHidden(batch_size)
        enc_embs = self.encoder_dropout_emb(self.encoder_embeddings(X))
        enc_out, hidden = self.encoder_gru(enc_embs, hidden)
        # (num_layers*2, bs, nh) -> (num_layers, bs, 2*nh): concatenate the two directions per layer
        hidden = (hidden.view(self.num_layers, 2, batch_size, -1)
                        .permute(0, 2, 1, 3).contiguous()
                        .view(self.num_layers, batch_size, -1))
        hidden = self.encoder_out(self.encoder_dropout(hidden))
        # decode forward
        dec_input = torch.zeros(batch_size, dtype=torch.long, device=X.device)  # token 0 acts as the start token
        w1e = enc_out @ self.W1  # precompute the encoder side of the attention score
        results = []
        attentions = []
        for i in range(self.out_seq_length):
            # additive attention over the source positions
            w2d = self.l2(hidden[-1])
            u = self.activation(w1e + w2d)
            a = F.softmax(u @ self.V, dim=0)  # softmax over the source sequence dimension
            attentions.append(a)
            Xa = (a.unsqueeze(2) * enc_out).sum(0)  # context vector: weighted sum of encoder outputs
            dec_embs = self.decoder_embeddings(dec_input)
            weight_enc = self.l3(torch.cat([dec_embs, Xa], dim=1))
            outp, hidden = self.decoder_gru(weight_enc.unsqueeze(0), hidden)
            outp = self.out(self.decoder_dropout(outp[0]))
            results.append(outp)
            # greedy next token; stop early once every sequence has emitted the pad/eos id (1)
            dec_input = outp.detach().argmax(dim=1)
            if (dec_input == 1).all():
                break
            # teacher forcing
            if (y is not None) and (random.random() < tf_ratio):
                if i >= len(y):
                    break
                # assign next value to decoder input
                dec_input = y[i]
        if return_attention:
            return torch.stack(results), torch.stack(attentions)
        return torch.stack(results)

    def initHidden(self, batch_size):
        # one zero state per encoder layer and direction, on the same device as the model
        device = next(self.parameters()).device
        return torch.zeros(self.num_layers * 2, batch_size, self.num_hidden, device=device)
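
The class above calls two helpers, create_embeddings and rand_p, that are not part of the paste. Below is a minimal sketch of what they could look like, inferred only from how they are used; the initialisation scale and the assumption that vecs_* is a dict mapping tokens to pretrained vectors of length em_sz are guesses, not the original implementation.

def rand_p(*sizes):
    # trainable tensor with a small random init, used above for W1 and V
    return nn.Parameter(torch.randn(*sizes) * 0.01)

def create_embeddings(vecs, idx2word, em_sz, pad_idx):
    # embedding table seeded from pretrained vectors where a token is covered;
    # uncovered tokens keep the default random init and the pad row stays zero
    emb = nn.Embedding(len(idx2word), em_sz, padding_idx=pad_idx)
    with torch.no_grad():
        for i, w in enumerate(idx2word):
            if i != pad_idx and w in vecs:
                emb.weight[i] = torch.tensor(vecs[w], dtype=torch.float)
    return emb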
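
A quick smoke test of the forward pass, with made-up toy vocabularies and random stand-ins for the pretrained vectors (every name and size here is illustrative only; index 0 plays the role of the start token and index 1 is the pad/stop id, matching the defaults above):

import numpy as np
import torch

idx2word_enc = ['<bos>', '<pad>', 'le', 'chat', 'dort']
idx2word_dec = ['<bos>', '<pad>', 'the', 'cat', 'sleeps']
vecs_enc = {w: np.random.randn(50) for w in idx2word_enc}
vecs_dec = {w: np.random.randn(50) for w in idx2word_dec}

model = Seq2SeqAttention(vecs_enc, idx2word_enc, 50, vecs_dec, idx2word_dec, 50,
                         num_hidden=64, out_seq_length=10)

X = torch.randint(2, len(idx2word_enc), (7, 4))   # (src_seq_len=7, batch=4) of source token ids
y = torch.randint(2, len(idx2word_dec), (10, 4))  # target ids used when teacher forcing kicks in
logits, attn = model(X, y, tf_ratio=0.5, return_attention=True)
print(logits.shape)  # (steps, batch, decoder vocab), steps <= out_seq_length
print(attn.shape)    # (steps, src_seq_len, batch) attention weights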