import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    def __init__(self, n_features, n_heads, n_hidden=64, dropout=0.1):
        """
        Args:
          n_features: Number of input and output features.
          n_heads: Number of attention heads in the Multi-Head Attention.
          n_hidden: Number of hidden units in the Feedforward (MLP) block.
          dropout: Dropout rate after the first layer of the MLP and in two places on the main path
              (before combining the main path with a skip connection).
        """
        super().__init__()
        # Position-wise feedforward (MLP) block.
        self.Feedforward = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Dropout(p=dropout),
            nn.ReLU(),
            nn.Linear(n_hidden, n_features),
        )
        self.MultiheadAttention = nn.MultiheadAttention(n_features, n_heads)
        self.Norm1 = nn.LayerNorm(n_features)
        self.Norm2 = nn.LayerNorm(n_features)
        self.Drop1 = nn.Dropout(p=dropout)
        self.Drop2 = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """
        Args:
          x of shape (max_seq_length, batch_size, n_features): Input sequences.
          mask of shape (batch_size, max_seq_length): Boolean tensor indicating which elements of the input
              sequences should be ignored.

        Returns:
          z of shape (max_seq_length, batch_size, n_features): Encoded input sequence.

        Note: All intermediate signals should be of shape (max_seq_length, batch_size, n_features).
        """
        # Self-attention sub-layer: padded positions are ignored via key_padding_mask.
        z, attn_weights = self.MultiheadAttention(x, x, x, key_padding_mask=mask)
        # Residual connection + layer norm around the attention sub-layer.
        z = self.Norm1(self.Drop1(z) + x)
        # Feedforward sub-layer with its own residual connection + layer norm.
        y = self.Feedforward(z)
        z = self.Norm2(self.Drop2(y) + z)
        return z
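

# A minimal usage sketch, assuming the block is applied standalone to a padded batch.
# The sizes below (batch of 2 sequences, max length 5, 16 features, 4 heads) are
# illustrative, not from the original paste.
if __name__ == "__main__":
    block = EncoderBlock(n_features=16, n_heads=4)

    x = torch.randn(5, 2, 16)                   # (max_seq_length, batch_size, n_features)
    mask = torch.zeros(2, 5, dtype=torch.bool)  # True marks padded positions to ignore
    mask[1, 3:] = True                          # pretend the second sequence has length 3

    z = block(x, mask)
    print(z.shape)                              # torch.Size([5, 2, 16])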