KDLPro

Custom RNN Architecture

Apr 24th, 2024
import torch
import torch.nn as nn


# Define our network class by subclassing nn.Module
class ResBlockMLP(nn.Module):
    def __init__(self, input_size, output_size):
        super(ResBlockMLP, self).__init__()
        self.norm1 = nn.LayerNorm(input_size)
        self.fc1 = nn.Linear(input_size, input_size // 2)

        self.norm2 = nn.LayerNorm(input_size // 2)
        self.fc2 = nn.Linear(input_size // 2, output_size)

        # Linear projection on the skip path so it matches output_size
        self.fc3 = nn.Linear(input_size, output_size)

        self.act = nn.ELU()

    def forward(self, x):
        x = self.act(self.norm1(x))
        skip = self.fc3(x)

        x = self.act(self.norm2(self.fc1(x)))
        x = self.fc2(x)

        # Residual connection: transformed features plus the projected skip
        return x + skip

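
# Quick shape note (sketch): ResBlockMLP maps (batch, input_size) -> (batch, output_size),
# e.g. block = ResBlockMLP(128, 128); block(torch.randn(4, 128)) has shape (4, 128).
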

# RNN-style model: an MLP encodes a flattened input sequence, and a learned
# "buffer" vector acts as the recurrent state carried between forward calls
class RNN(nn.Module):
    def __init__(self, seq_len, output_size, num_blocks=1, buffer_size=128):
        super(RNN, self).__init__()

        # Flattened sequence length, assuming 2 features per timestep
        seq_data_len = seq_len * 2

        self.input_mlp = nn.Sequential(nn.Linear(seq_data_len, 4 * seq_data_len),
                                       nn.ELU(),
                                       nn.Linear(4 * seq_data_len, 128),
                                       nn.ELU())

        # Mixes the concatenated [buffer_in, input_vec] (buffer_size + 128 features)
        # back down to 128 features
        self.rnn = nn.Linear(buffer_size + 128, 128)

        blocks = [ResBlockMLP(128, 128) for _ in range(num_blocks)]
        self.res_blocks = nn.Sequential(*blocks)

        self.fc_out = nn.Linear(128, output_size)
        self.fc_buffer = nn.Linear(128, buffer_size)
        self.act = nn.ELU()

    def forward(self, input_seq, buffer_in):
        # Flatten the sequence to one vector per batch element
        input_seq = input_seq.reshape(input_seq.shape[0], -1)
        input_vec = self.input_mlp(input_seq)

        # Concatenate the previous step's buffer with the encoded input
        x_cat = torch.cat((buffer_in, input_vec), 1)
        x = self.rnn(x_cat)

        x = self.act(self.res_blocks(x))

        # Return the prediction and the next buffer (tanh keeps it in [-1, 1])
        return self.fc_out(x), torch.tanh(self.fc_buffer(x))
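

# Minimal usage sketch (illustrative shapes and data, not tied to any particular
# training setup): unroll the model over chunks of a sequence, feeding each step's
# returned buffer back in as buffer_in for the next step.
if __name__ == "__main__":
    seq_len, output_size, batch_size, buffer_size = 10, 2, 4, 128

    model = RNN(seq_len=seq_len, output_size=output_size, num_blocks=2, buffer_size=buffer_size)

    # Hypothetical data: 5 chunks of shape (batch, seq_len, 2), matching seq_data_len = seq_len * 2
    chunks = [torch.randn(batch_size, seq_len, 2) for _ in range(5)]

    # The recurrent buffer starts at zero and is carried across chunks
    buffer = torch.zeros(batch_size, buffer_size)
    for chunk in chunks:
        pred, buffer = model(chunk, buffer)

    print(pred.shape)    # torch.Size([4, 2])
    print(buffer.shape)  # torch.Size([4, 128])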