import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    def __init__(self, tgt_dictionary_size, embed_size, hidden_size):
        """
        Args:
          tgt_dictionary_size: The number of words in the target dictionary.
          embed_size: The number of dimensions in the word embeddings.
          hidden_size: The number of features in the hidden state.
        """
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.tgt_dictionary_size = tgt_dictionary_size

        # Embed target words to embed_size, then run a single-layer GRU.
        self.embedding = nn.Embedding(tgt_dictionary_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size)
        self.out = nn.Linear(hidden_size, tgt_dictionary_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, hidden, pad_tgt_seqs=None, teacher_forcing=False):
        """
        Args:
          hidden of shape (1, batch_size, hidden_size): States of the GRU.
          pad_tgt_seqs of shape (max_out_seq_length, batch_size): Tensor of words (word indices) of the
              target sentence. If None, the output sequence is generated by feeding the decoder's own
              outputs (teacher_forcing has to be False).
          teacher_forcing (bool): Whether to use teacher forcing or not.

        Returns:
          outputs of shape (max_out_seq_length, batch_size, tgt_dictionary_size): Tensor of log-probabilities
              of words in the target language.
          hidden of shape (1, batch_size, hidden_size): New states of the GRU.

        Note: Do not forget to transfer tensors that you may want to create in this function to the device
        specified by `hidden.device`.
        """
        if pad_tgt_seqs is None:
            assert not teacher_forcing, 'Cannot use teacher forcing without a target sequence.'

        batch_size = hidden.size(1)
        device = hidden.device

        # Start every sequence with the start-of-sentence token.
        prev_word = torch.full((1, batch_size), SOS_token, dtype=torch.int64, device=device)

        max_length = pad_tgt_seqs.size(0) if pad_tgt_seqs is not None else MAX_LENGTH
        outputs = torch.zeros(max_length, batch_size, self.tgt_dictionary_size,
                              device=device, dtype=torch.float32)
        for t in range(max_length):
            # Embed the previous word and take one GRU step.
            output = self.embedding(prev_word).view(1, batch_size, -1)
            output = F.relu(output)
            output, hidden = self.gru(output, hidden)
            # Project to the vocabulary and take log-probabilities: (batch_size, tgt_dictionary_size).
            output = self.softmax(self.out(output[0]))
            outputs[t] = output
            if teacher_forcing:
                # Feed the target word as the next input.
                prev_word = pad_tgt_seqs[t].view(1, batch_size)
            else:
                # Use the decoder's own prediction as the next input.
                _, topi = output.topk(k=1, dim=1)
                prev_word = topi.detach().view(1, batch_size)  # detach from history as input
        return outputs, hidden

    def initHidden(self):
        # Assumes a global `device` defined elsewhere in the notebook.
        return torch.zeros(1, 1, self.hidden_size, device=device)
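

# A minimal usage sketch, not part of the original paste. It assumes the globals the
# class relies on (SOS_token, MAX_LENGTH, device) live in the same notebook, and the
# vocabulary/embedding/hidden sizes and batch size below are made up for illustration.
# In a real seq2seq model, `hidden` would be the encoder's final hidden state.

SOS_token = 0               # assumed value for illustration
MAX_LENGTH = 10             # assumed value for illustration
device = torch.device('cpu')

decoder = Decoder(tgt_dictionary_size=1000, embed_size=256, hidden_size=256).to(device)

batch_size = 4
hidden = torch.zeros(1, batch_size, 256, device=device)

# Free-running generation: no target sequence, so teacher forcing must be off.
outputs, hidden_out = decoder(hidden)
print(outputs.shape)        # torch.Size([10, 4, 1000]) -> (MAX_LENGTH, batch_size, tgt_dictionary_size)

# Training-style call with teacher forcing: feed the padded target sequence.
pad_tgt_seqs = torch.randint(0, 1000, (7, batch_size), device=device)
outputs, hidden_out = decoder(hidden, pad_tgt_seqs, teacher_forcing=True)
print(outputs.shape)        # torch.Size([7, 4, 1000])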