import torch
import torch.nn as nn
import torch.nn.functional as F


# Toy seq2seq Transformer used to experiment with decoder attention masks.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.seqLength = 5
        self.d_model = 3
        self.nhead = 1
        self.num_encoder_layers = 6
        self.num_decoder_layers = 6
        self.dim_feedforward = 1024
        self.dropout = 0.1
        self.activation = F.relu
        self.batch_first = True
        self.transformer = nn.Transformer(
            d_model=self.d_model,
            nhead=self.nhead,
            num_encoder_layers=self.num_encoder_layers,
            num_decoder_layers=self.num_decoder_layers,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation=self.activation,
            batch_first=self.batch_first,
        )

    def forward(self, encoderSeq, decoderSeq, decoderPaddingMask=None, decoderMask=None):
        # tgt_mask is an additive float mask over decoder self-attention;
        # tgt_key_padding_mask is a boolean per-position mask (True = ignore).
        y_hat = self.transformer(
            src=encoderSeq,
            tgt=decoderSeq,
            tgt_mask=decoderMask,
            tgt_key_padding_mask=decoderPaddingMask,
        )
        return y_hat


def generateDecoderMask(size, n):
    """Build a (size, size) additive "reverse L" mask: -inf on the last n key
    columns for every query, and -inf everywhere else on the last n query
    rows. Note the last n query rows end up fully masked."""
    output = torch.zeros((size, size))
    if not n:
        return output
    output[:, -n:] = -torch.inf
    output[-n:, :-n] = -torch.inf
    return output
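

# For reference (assumption: a recent PyTorch where this helper is a static
# method), the library ships a standard causal mask builder; it returns a
# float mask with 0.0 on/below the diagonal and -inf above it:
causalMask = nn.Transformer.generate_square_subsequent_mask(5)
print(causalMask)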


model = Model()
model.eval()  # disable dropout so forward passes are deterministic

encSeq = torch.rand((1, model.seqLength, model.d_model))
decSeq = torch.rand((1, model.seqLength, model.d_model))

# Try with reverse L mask
decoderMask = generateDecoderMask(model.seqLength, 1)
output = model(encSeq, decSeq, decoderMask=decoderMask)
print(decoderMask)
print(output, "\n\n\n")
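# The reverse-L mask leaves the last query row with every key at -inf, so its
# softmax has no valid entry; the output then typically contains NaNs
# (assumption: this PyTorch build propagates NaN for fully masked rows
# rather than raising):
print(torch.isnan(output).any())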

# Try using tgt_key_padding_mask
# √ (works)
decoderMask = torch.tensor([0, 0, 0, 0, 1]).unsqueeze(0).bool()
output = model(encSeq, decSeq, decoderPaddingMask=decoderMask)
print(decoderMask)
print(output, "\n\n\n")

# Try with only columns set to -inf
# √ (works: every query row still has unmasked keys; only the last key position is hidden)
decoderMask = torch.zeros((model.seqLength, model.seqLength))
decoderMask[:, -1] = -torch.inf
output = model(encSeq, decSeq, decoderMask=decoderMask)
print(decoderMask)
print(output, "\n\n\n")
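
# Cross-check (assumption, not in the original paste): the padding mask and
# the column-only float mask above both hide key position 4 in decoder
# self-attention, so their outputs ought to match.
paddingMask = torch.tensor([[False, False, False, False, True]])
columnMask = torch.zeros((model.seqLength, model.seqLength))
columnMask[:, -1] = -torch.inf
outPad = model(encSeq, decSeq, decoderPaddingMask=paddingMask)
outCol = model(encSeq, decSeq, decoderMask=columnMask)
print(torch.allclose(outPad, outCol))  # expected True if the two mask forms are equivalent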