import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    def __init__(self, tgt_dictionary_size, embed_size, hidden_size):
        """
        Args:
          tgt_dictionary_size: The number of words in the target dictionary.
          embed_size: The number of dimensions in the word embeddings.
          hidden_size: The number of features in the hidden state.
        """
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.tgt_dictionary_size = tgt_dictionary_size
        # Embeddings have embed_size dimensions (the docstring says so); the GRU
        # then maps embed_size inputs to hidden_size features.
        self.embedding = nn.Embedding(tgt_dictionary_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size)
        self.out = nn.Linear(hidden_size, tgt_dictionary_size)
        # Normalize over the vocabulary dimension of the (1, batch, vocab) output.
        self.softmax = nn.LogSoftmax(dim=2)
    def forward(self, hidden, pad_tgt_seqs=None, teacher_forcing=False):
        """
        Args:
          hidden of shape (1, batch_size, hidden_size): States of the GRU.
          pad_tgt_seqs of shape (max_out_seq_length, batch_size): Tensor of words (word indices) of the
              target sentence. If None, the output sequence is generated by feeding back the decoder's
              own outputs (teacher_forcing has to be False).
          teacher_forcing (bool): Whether to use teacher forcing or not.

        Returns:
          outputs of shape (max_out_seq_length, batch_size, tgt_dictionary_size): Tensor of log-probabilities
              of words in the target language.
          hidden of shape (1, batch_size, hidden_size): New states of the GRU.

        Note: Do not forget to transfer tensors that you may want to create in this function to the device
        specified by `hidden.device`.
        """
        if pad_tgt_seqs is None:
            assert not teacher_forcing, 'Cannot use teacher forcing without a target sequence.'
        batch_size = hidden.size(1)
        # Create new tensors on the same device as the GRU state (see the note above).
        device = hidden.device
        # Every sequence starts with the start-of-sequence token.
        prev_word = torch.full((1, batch_size), SOS_token, dtype=torch.int64, device=device)
        max_length = pad_tgt_seqs.size(0) if pad_tgt_seqs is not None else MAX_LENGTH
        outputs = torch.zeros(max_length, batch_size, self.tgt_dictionary_size,
                              device=device, dtype=torch.float32)
        for t in range(max_length):
            embedded = self.embedding(prev_word).view(1, batch_size, -1)  # (1, batch, embed)
            embedded = F.relu(embedded)
            output, hidden = self.gru(embedded, hidden)  # output: (1, batch, hidden)
            output = self.softmax(self.out(output))      # log-probabilities: (1, batch, vocab)
            outputs[t] = output[0]
            if teacher_forcing:
                # Feed the ground-truth target word as the next input
                prev_word = pad_tgt_seqs[t].view(1, batch_size)
            else:
                # Use the decoder's own greedy prediction as the next input
                _, topi = output.topk(k=1, dim=2)              # indices: (1, batch, 1)
                prev_word = topi.view(1, batch_size).detach()  # detach from history as input
        return outputs, hidden
    def initHidden(self, device='cpu'):
        # Zero initial state for a batch of one; the target device is passed in
        # explicitly instead of relying on an undefined global `device`.
        return torch.zeros(1, 1, self.hidden_size, device=device)
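
# --- Usage sketch (added; not part of the original paste) ---
# A minimal smoke test of the Decoder above. It assumes the two module-level
# constants the class refers to: SOS_token (index of the start-of-sequence
# word) and MAX_LENGTH (maximum generated sequence length). The vocabulary
# and layer sizes below are arbitrary illustration values.
if __name__ == '__main__':
    SOS_token = 0    # assumed convention; adjust to your vocabulary
    MAX_LENGTH = 10  # assumed generation cap when no target is given

    decoder = Decoder(tgt_dictionary_size=100, embed_size=64, hidden_size=64)

    # Free-running generation: no target sequence, no teacher forcing.
    hidden = decoder.initHidden()  # (1, 1, hidden_size)
    outputs, hidden = decoder(hidden)
    print(outputs.shape)  # torch.Size([10, 1, 100])

    # Teacher forcing with a dummy 5-step target for a batch of one:
    # the ground-truth word at step t is fed as the input at step t+1.
    tgt = torch.randint(0, 100, (5, 1))
    hidden = decoder.initHidden()
    outputs, hidden = decoder(hidden, pad_tgt_seqs=tgt, teacher_forcing=True)
    print(outputs.shape)  # torch.Size([5, 1, 100])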