import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    def __init__(self, n_features, n_heads, n_hidden=64, dropout=0.1):
        """
        Args:
          n_features: Number of input and output features.
          n_heads: Number of attention heads in the Multi-Head Attention.
          n_hidden: Number of hidden units in the Feedforward (MLP) block.
          dropout: Dropout rate after the first layer of the MLP and in two places on the main path
                   (before combining the main path with a skip connection).
        """
        super(EncoderBlock, self).__init__()
        # Position-wise feedforward (MLP) sub-block.
        self.Feedforward = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Dropout(p=dropout),
            nn.ReLU(),
            nn.Linear(n_hidden, n_features),
        )
        self.MultiheadAttention = nn.MultiheadAttention(n_features, n_heads)
        self.Norm1 = nn.LayerNorm(n_features)
        self.Norm2 = nn.LayerNorm(n_features)
        self.Drop1 = nn.Dropout(p=dropout)
        self.Drop2 = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """
        Args:
          x of shape (max_seq_length, batch_size, n_features): Input sequences.
          mask of shape (batch_size, max_seq_length): Boolean tensor indicating which elements of the input
              sequences should be ignored.

        Returns:
          z of shape (max_seq_length, batch_size, n_features): Encoded input sequence.

        Note: All intermediate signals should be of shape (max_seq_length, batch_size, n_features).
        """
        # Self-attention sub-block: the padding mask hides the positions that should be ignored.
        attn_out, _ = self.MultiheadAttention(x, x, x, key_padding_mask=mask)
        z = self.Norm1(self.Drop1(attn_out) + x)
        # Feedforward sub-block with the second residual (skip) connection.
        y = self.Feedforward(z)
        z = self.Norm2(self.Drop2(y) + z)
        return z
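

# Minimal usage sketch (an assumption, not part of the original paste): builds one
# EncoderBlock and runs a forward pass on random data with a boolean padding mask,
# where True marks positions the attention should ignore. Shapes follow the docstrings above.
if __name__ == "__main__":
    max_seq_length, batch_size, n_features = 10, 4, 32
    block = EncoderBlock(n_features=n_features, n_heads=4, n_hidden=64, dropout=0.1)

    x = torch.randn(max_seq_length, batch_size, n_features)
    # Mark the last three positions of every sequence as padding.
    mask = torch.zeros(batch_size, max_seq_length, dtype=torch.bool)
    mask[:, -3:] = True

    z = block(x, mask)
    print(z.shape)  # expected: torch.Size([10, 4, 32])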