import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.seqLength = 5
        self.d_model = 3
        self.nhead = 1
        self.num_encoder_layers = 6
        self.num_decoder_layers = 6
        self.dim_feedforward = 1024
        self.dropout = 0.1
        self.activation = F.relu
        self.batch_first = True

        self.transformer = nn.Transformer(
                d_model=self.d_model,
                nhead=self.nhead,
                num_encoder_layers=self.num_encoder_layers,
                num_decoder_layers=self.num_decoder_layers,
                dim_feedforward=self.dim_feedforward,
                dropout=self.dropout,
                activation=self.activation,
                batch_first=self.batch_first
                )

    def forward(self, encoderSeq, decoderSeq, decoderPaddingMask=None, decoderMask=None):
        y_hat = self.transformer(src=encoderSeq, tgt=decoderSeq, tgt_mask=decoderMask,
                                 tgt_key_padding_mask=decoderPaddingMask)
        return y_hat
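
# Mask conventions in nn.Transformer (per the PyTorch docs):
#   tgt_mask (decoderMask here): shape (seqLength, seqLength); a float mask is
#     added to the attention logits, so 0. means "attend" and -inf means "block".
#   tgt_key_padding_mask (decoderPaddingMask here): shape (batch, seqLength),
#     bool; True marks a target position to be ignored as an attention key.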

def generateDecoderMask(size, n):
    """Build a (size, size) additive "reverse L" mask: block the last n
    positions as keys for every query, and block the last n queries from
    attending to the remaining size - n positions."""
    output = torch.zeros((size, size))
    if not n:
        return output
    output[:, -n:] = -torch.inf
    output[-n:, :-n] = -torch.inf
    return output
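
# Example: generateDecoderMask(5, 1) yields (0. = attend, -inf = blocked):
# [[0., 0., 0., 0., -inf],
#  [0., 0., 0., 0., -inf],
#  [0., 0., 0., 0., -inf],
#  [0., 0., 0., 0., -inf],
#  [-inf, -inf, -inf, -inf, -inf]]
# The last row is entirely -inf, so that query's softmax has no finite entry
# and returns NaN, which then propagates through the decoder.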

model = Model()
model.eval()  # disable dropout so the runs below are deterministic

encSeq = torch.rand((1, model.seqLength, model.d_model))
decSeq = torch.rand((1, model.seqLength, model.d_model))

# Try with reverse L mask: the fully masked last row makes the attention
# softmax return NaN there, and the NaNs spread through later layers.
decoderMask = generateDecoderMask(model.seqLength, 1)
output = model(encSeq, decSeq, decoderMask=decoderMask)
print(decoderMask)
print(output, "\n\n\n")
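
# For comparison, a minimal sketch (not in the original paste) using the stock
# causal mask helper: lower-triangular, so every query keeps at least one
# unmasked key and no softmax row is all -inf.
causalMask = model.transformer.generate_square_subsequent_mask(model.seqLength)
output = model(encSeq, decSeq, decoderMask=causalMask)
print(causalMask)
print(output, "\n\n\n")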

# Try using tgt_key_padding_mask instead (works: no NaNs).
decoderPadMask = torch.tensor([0, 0, 0, 0, 1]).unsqueeze(0).bool()
output = model(encSeq, decSeq, decoderPaddingMask=decoderPadMask)
print(decoderPadMask)
print(output, "\n\n\n")
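
# Note: a key padding mask only hides the marked position from attention *keys*;
# its own query row still attends to the remaining unmasked positions, so that
# position still gets a finite output of its own.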

# Try with only columns set to -inf (works: no NaNs).
# Masking just the last column leaves every query row with unmasked keys.
decoderMask = torch.zeros((model.seqLength, model.seqLength))
decoderMask[:, -1] = -torch.inf
output = model(encSeq, decSeq, decoderMask=decoderMask)
print(decoderMask)
print(output, "\n\n\n")
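
# Sanity check (an added assumption, not part of the original paste): the bool
# padding mask and the column-only float mask block exactly the same key, so
# in eval mode the two outputs should match.
paddingOut = model(encSeq, decSeq, decoderPaddingMask=decoderPadMask)
columnOut = model(encSeq, decSeq, decoderMask=decoderMask)
print(torch.allclose(paddingOut, columnOut))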