Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class Decoder(nn.Module):
    """Transformer decoder: embeds and position-encodes target tokens, runs them
    through a stack of DecoderBlocks attending to the encoder output, and maps
    the result to log-softmax word probabilities."""

    def __init__(self, tgt_vocab_size, n_blocks, n_features, n_heads, n_hidden=64, dropout=0.1):
        """
        Args:
          tgt_vocab_size: Number of words in the target vocabulary.
          n_blocks: Number of DecoderBlock blocks.
          n_features: Number of features used for word embedding and in all
              layers of the decoder.
          n_heads: Number of attention heads inside each DecoderBlock.
          n_hidden: Number of hidden units in the Feedforward part of DecoderBlock.
          dropout: Dropout level used in DecoderBlock and positional encoding.
        """
        super().__init__()  # modern zero-arg super (was super(Decoder, self))
        self.decoder = nn.ModuleList(
            [DecoderBlock(n_features, n_heads, n_hidden=n_hidden, dropout=dropout)
             for _ in range(n_blocks)]
        )
        # use nn.Embedding for consistency with the other nn.* layers here
        self.embedding = nn.Embedding(tgt_vocab_size, n_features)
        self.pos_encoding = tr.PositionalEncoding(n_features, dropout=dropout, max_len=MAX_LENGTH)
        self.linear = nn.Linear(n_features, tgt_vocab_size)
        # dim=2 normalizes over the vocabulary axis of (seq, batch, vocab)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, y, z, src_mask):
        """
        Args:
          y of shape (max_tgt_seq_length, batch_size): Transformed target
              sequences used as the inputs of the block.
          z of shape (max_src_seq_length, batch_size, n_features): Encoded
              source sequences (outputs of the encoder).
          src_mask of shape (batch_size, max_src_seq_length): Boolean tensor
              indicating which elements of the source sequences should be ignored.

        Returns:
          out of shape (max_tgt_seq_length, batch_size, tgt_vocab_size):
              Log-softmax probabilities of the words in the output sequences.
        """
        # Causal mask so position i cannot attend to later target positions.
        # NOTE(review): assumes subsequent_mask() returns a tensor usable on
        # y's device — confirm; otherwise append .to(y.device).
        tgt_mask = subsequent_mask(y.size(0))
        out = self.pos_encoding(self.embedding(y))
        for block in self.decoder:  # index was unused; iterate modules directly
            out = block(out, z, src_mask, tgt_mask)
        out = self.linear(out)
        return self.softmax(out)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement