import torch
import torch.nn as nn
import torch.nn.functional as F

# move the source batch to the GPU when available
batch_src = batch.src.cuda() if use_cuda else batch.src
batch_trg, target = create_trg_batch(batch.trg)

# encode the source; the encoder's final hidden state seeds the decoder
encoder_hidden = encoder.init_hidden(batch_src.size(1))
encoder_output, decoder_hidden = encoder(batch_src, encoder_hidden)
# encoder_output: (src_len, batch, hidden_size)

embedding = nn.Embedding(trg_vocab_size, embedding_size)
lstm = nn.LSTM(embedding_size, hidden_size, n_layers, dropout=dropout)
if use_cuda:
    embedding, lstm = embedding.cuda(), lstm.cuda()

embedded = embedding(batch_trg)
decoder_output, decoder_hidden = lstm(embedded, decoder_hidden)
# decoder_output: (trg_len, batch, hidden_size)

# attention scores: batched dot product of encoder and decoder states
# (batch, src_len, hidden) x (batch, hidden, trg_len) -> (batch, src_len, trg_len)
alpha_ij = torch.bmm(encoder_output.transpose(0, 1), decoder_output.permute(1, 2, 0))
# normalise over source positions, then reshape to (batch, trg_len, src_len)
alpha_ij = F.softmax(alpha_ij, dim=1).transpose(1, 2)

# context vectors: attention-weighted sum of encoder states
# (batch, trg_len, src_len) x (batch, src_len, hidden) -> (batch, trg_len, hidden)
context = torch.bmm(alpha_ij, encoder_output.transpose(0, 1))

# concatenate context and decoder output along the feature dimension
cat = torch.cat((context.transpose(0, 1), decoder_output), dim=2)  # (trg_len, batch, 2 * hidden)

# project onto the target vocabulary
linear = nn.Linear(2 * hidden_size, trg_vocab_size)
if use_cuda:
    linear = linear.cuda()

decoded = linear(cat.view(cat.size(0) * cat.size(1), cat.size(2)))
decoded = decoded.view(cat.size(0), cat.size(1), decoded.size(1))  # (trg_len, batch, trg_vocab_size)
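
# A minimal, self-contained sanity check of the same dot-attention shape flow,
# using dummy tensors. The dimensions below are arbitrary and purely
# illustrative; enc_out and dec_out stand in for the encoder_output and
# decoder_output produced above.
src_len, trg_len, batch_size, hidden = 7, 5, 4, 16
enc_out = torch.randn(src_len, batch_size, hidden)
dec_out = torch.randn(trg_len, batch_size, hidden)

scores = torch.bmm(enc_out.transpose(0, 1), dec_out.permute(1, 2, 0))
weights = F.softmax(scores, dim=1).transpose(1, 2)   # (batch, trg_len, src_len)
ctx = torch.bmm(weights, enc_out.transpose(0, 1))    # (batch, trg_len, hidden)

assert weights.shape == (batch_size, trg_len, src_len)
assert ctx.shape == (batch_size, trg_len, hidden)
# each attention distribution sums to 1 over the source positions
assert torch.allclose(weights.sum(dim=2), torch.ones(batch_size, trg_len))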