import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if use_cuda else 'cpu')

# Source batch and (decoder input, gold target) pair from the target batch.
batch_src = batch.src.to(device)
batch_trg, target = create_trg_batch(batch.trg)

# Encode the source; the encoder's final hidden state initialises the decoder.
encoder_hidden = encoder.init_hidden(batch_src.size(1))
encoder_output, decoder_hidden = encoder(batch_src, encoder_hidden)

# Decoder: embedding + LSTM over the target sequence.
embedding = nn.Embedding(trg_vocab_size, embedding_size).to(device)
lstm = nn.LSTM(embedding_size, hidden_size, n_layers, dropout=dropout).to(device)
embedded = embedding(batch_trg)
decoder_output, decoder_hidden = lstm(embedded, decoder_hidden)

# Dot-product attention: score every (target step, source step) pair with bmm.
# encoder_output: (src_len, batch, hidden), decoder_output: (trg_len, batch, hidden)
alpha_ij = torch.bmm(encoder_output.transpose(0, 1), decoder_output.permute(1, 2, 0))
# Normalise over source positions, then reshape to (batch, trg_len, src_len).
alpha_ij = F.softmax(alpha_ij, dim=1).transpose(1, 2)
# Context vectors: attention-weighted sum of encoder states, (batch, trg_len, hidden).
context = torch.bmm(alpha_ij, encoder_output.transpose(0, 1))
# Concatenate context and decoder output along the feature dimension.
cat = torch.cat((context.transpose(0, 1), decoder_output), dim=2)

# Project the concatenation onto the target vocabulary.
linear = nn.Linear(2 * hidden_size, trg_vocab_size).to(device)
decoded = linear(cat.view(cat.size(0) * cat.size(1), cat.size(2)))
decoded = decoded.view(cat.size(0), cat.size(1), decoded.size(1))
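
# Minimal sketch of the attention step alone, run on random tensors so the
# shapes can be checked in isolation. The sizes (src_len, trg_len, bsz, hid)
# are made up for illustration and are not part of the snippet above.
src_len, trg_len, bsz, hid = 7, 5, 3, 16
enc_out = torch.randn(src_len, bsz, hid)
dec_out = torch.randn(trg_len, bsz, hid)
scores = torch.bmm(enc_out.transpose(0, 1), dec_out.permute(1, 2, 0))  # (bsz, src_len, trg_len)
weights = F.softmax(scores, dim=1).transpose(1, 2)                     # (bsz, trg_len, src_len)
ctx = torch.bmm(weights, enc_out.transpose(0, 1))                      # (bsz, trg_len, hid)
assert ctx.shape == (bsz, trg_len, hid)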