import torch
import torch.nn as nn

# DecoderBlock, tr.PositionalEncoding, MAX_LENGTH and subsequent_mask are
# given earlier in the Decoder part of the assignment.


class Decoder(nn.Module):
    def __init__(self, tgt_vocab_size, n_blocks, n_features, n_heads, n_hidden=64, dropout=0.1):
        """
        Args:
          tgt_vocab_size: Number of words in the target vocabulary.
          n_blocks: Number of DecoderBlock blocks.
          n_features: Number of features used for the word embeddings and in all layers of the decoder.
          n_heads: Number of attention heads inside the DecoderBlock.
          n_hidden: Number of hidden units in the feedforward sublayer of the DecoderBlock.
          dropout: Dropout level used in the DecoderBlock.
        """
        super(Decoder, self).__init__()
        self.decoder = nn.ModuleList(
            [DecoderBlock(n_features, n_heads, n_hidden=n_hidden, dropout=dropout) for _ in range(n_blocks)]
        )
        self.embedding = nn.Embedding(tgt_vocab_size, n_features)
        self.pos_encoding = tr.PositionalEncoding(n_features, dropout=dropout, max_len=MAX_LENGTH)
        self.linear = nn.Linear(n_features, tgt_vocab_size)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, y, z, src_mask):
        """
        Args:
          y of shape (max_tgt_seq_length, batch_size): Transformed target sequences used as the
              inputs of the block.
          z of shape (max_src_seq_length, batch_size, n_features): Encoded source sequences
              (outputs of the encoder).
          src_mask of shape (batch_size, max_src_seq_length): Boolean tensor indicating which
              elements of the source sequences should be ignored.

        Returns:
          out of shape (max_tgt_seq_length, batch_size, tgt_vocab_size): Log-softmax probabilities
              of the words in the output sequences.

        Notes:
          * All intermediate signals should be of shape (max_tgt_seq_length, batch_size, n_features).
          * You need to create and use the subsequent mask in the decoder.
        """
        # Causal mask over the target positions (subsequent_mask is the given
        # function from the beginning of the Decoder part).
        tgt_mask = subsequent_mask(y.size(0))
        # Embed the target words and add positional information.
        out = self.pos_encoding(self.embedding(y))
        # Run the signal through the stack of decoder blocks.
        for block in self.decoder:
            out = block(out, z, src_mask, tgt_mask)
        # Project to the vocabulary and convert to log-probabilities.
        out = self.linear(out)
        out = self.softmax(out)
        return out
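
The class above depends on names defined elsewhere in the assignment: DecoderBlock, tr.PositionalEncoding, MAX_LENGTH and subsequent_mask. Below is a minimal sketch of plausible stand-ins so the paste can be run on its own; these are illustrative assumptions, not the course's actual implementations. The subsequent mask here returns True for positions a target word may attend to, and the DecoderBlock is built from PyTorch's nn.MultiheadAttention with post-norm residual sublayers.

import math
import types

import torch
import torch.nn as nn

MAX_LENGTH = 128  # hypothetical value; the original assignment defines this elsewhere


def subsequent_mask(size):
    # (size, size) boolean mask; True = position may be attended to.
    # Lower-triangular, so position i only sees positions <= i.
    return torch.tril(torch.ones(size, size, dtype=torch.bool))


class PositionalEncoding(nn.Module):
    # Standard sinusoidal positional encoding ("Attention Is All You Need").
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)  # (max_len, 1, d_model)

    def forward(self, x):
        # x: (seq_length, batch_size, d_model)
        return self.dropout(x + self.pe[:x.size(0)])


# The paste accesses PositionalEncoding through a module-like name `tr`.
tr = types.SimpleNamespace(PositionalEncoding=PositionalEncoding)


class DecoderBlock(nn.Module):
    # One decoder block: masked self-attention, encoder-decoder attention,
    # and a feedforward sublayer, each wrapped in a residual + LayerNorm.
    def __init__(self, n_features, n_heads, n_hidden=64, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(n_features, n_heads, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(n_features, n_heads, dropout=dropout)
        self.ff = nn.Sequential(nn.Linear(n_features, n_hidden), nn.ReLU(),
                                nn.Linear(n_hidden, n_features))
        self.norm1 = nn.LayerNorm(n_features)
        self.norm2 = nn.LayerNorm(n_features)
        self.norm3 = nn.LayerNorm(n_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, z, src_mask, tgt_mask):
        # nn.MultiheadAttention expects True = blocked in attn_mask, so the
        # subsequent mask (True = allowed) is inverted here.
        attn, _ = self.self_attn(y, y, y, attn_mask=~tgt_mask)
        y = self.norm1(y + self.dropout(attn))
        # src_mask already uses key_padding_mask semantics (True = ignore).
        attn, _ = self.cross_attn(y, z, z, key_padding_mask=src_mask)
        y = self.norm2(y + self.dropout(attn))
        return self.norm3(y + self.dropout(self.ff(y)))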
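A quick shape check using the sketch above (the vocabulary size, sequence lengths and dimensions below are arbitrary):

torch.manual_seed(0)
decoder = Decoder(tgt_vocab_size=1000, n_blocks=2, n_features=32, n_heads=4)

y = torch.randint(0, 1000, (7, 3))               # (max_tgt_seq_length, batch_size)
z = torch.randn(11, 3, 32)                       # (max_src_seq_length, batch_size, n_features)
src_mask = torch.zeros(3, 11, dtype=torch.bool)  # no source positions padded

out = decoder(y, z, src_mask)
print(out.shape)                                 # torch.Size([7, 3, 1000])
# Each position's probabilities sum to one because of the final log-softmax.
print(out.exp().sum(dim=2))                      # all ones (up to numerical precision)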