PioneerAlexander

Untitled

May 5th, 2025 (edited)
326
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.45 KB | None | 0 0
  1. import torch
  2.  
  3. import torch.nn as nn
  4.  
  5. class MyLSTM(nn.Module):
  6.     def __init__(
  7.         self,
  8.         input_size,
  9.         output_size,
  10.         hidden_size,
  11.         fc_depth_input,
  12.         fc_depth_output,
  13.         n_layers,
  14.         dropout: float = 0.1,
  15.         use_layer_norm: bool = False,
  16.     ):
  17.         super().__init__()
  18.  
  19.         assert n_layers >= 1, f"Expected at least 1 layer for LSTM, but found {n_layers}"
  20.         self.dropout_layer = nn.Dropout(p=dropout)
  21.         self.fc_input = nn.ModuleList(
  22.             [layer for _ in range(fc_depth_input) for layer in (nn.Linear(input_size, input_size), nn.ReLU(), nn.Dropout(p=dropout))]
  23.         )
  24.         self.lstm = nn.ModuleList()
  25.        
  26.         for idx in range(n_layers):
  27.             self.lstm.append(MyLSTMLayer(layer_input_size=input_size if idx == 0 else hidden_size, layer_hidden_size=hidden_size))
  28.  
  29.         self.fc_output = nn.ModuleList(
  30.             [layer for _ in range(fc_depth_output) for layer in (nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(p=dropout))]
  31.         )
  32.         self.fc_last = nn.Linear(hidden_size, output_size)
  33.  
  34.     def forward(
  35.         self,
  36.         x,
  37.     ):
  38.         for layer in self.fc_input:
  39.             x = layer(x)
  40.  
  41.         for lstm_layer in self.lstm:
  42.             x = lstm_layer(x)
  43.             x = self.dropout_layer(x)
  44.  
  45.         for layer in self.fc_output:
  46.             x = layer(x)
  47.         x = self.fc_last(x)
  48.  
  49.         return x
  50.  
  51. class MyLSTMLayer(nn.Module):
  52.     def __init__(
  53.         self,
  54.         layer_input_size,
  55.         layer_hidden_size,
  56.     ):
  57.         super().__init__()
  58.  
  59.         self.layer_input_size = layer_input_size
  60.         self.layer_hidden_size = layer_hidden_size
  61.  
  62.         self.input_weights = nn.Linear(in_features=layer_input_size, out_features=layer_hidden_size * 4, bias=True)
  63.         self.hidden_weights = nn.Linear(in_features=layer_hidden_size, out_features=layer_hidden_size * 4, bias=True)
  64.         self.tanh = torch.tanh
  65.         self.sigmoid = torch.sigmoid
  66.  
  67.     def forward(
  68.         self,
  69.         x,
  70.     ):
  71.         c_t = torch.zeros((
  72.             x.shape[0],
  73.             1,
  74.             self.layer_hidden_size,
  75.         )).to(x.device)
  76.         h_t = torch.zeros((
  77.             x.shape[0],
  78.             1,
  79.             self.layer_hidden_size,
  80.         )).to(x.device)
  81.        
  82.         output_inputs = self.input_weights(x)
  83.        
  84.         out = torch.empty((
  85.             x.shape[0],
  86.             x.shape[1],
  87.             self.layer_hidden_size,
  88.         ), device=x.device)
  89.         for t in range(x.shape[1]): # iterate over seq_len
  90.             output_inputs_t = output_inputs[:, t, :].unsqueeze(1)
  91.             output_hiddens_t = self.hidden_weights(h_t)
  92.  
  93.             gates_inputs = output_inputs_t + output_hiddens_t
  94.  
  95.             input_gate = self.sigmoid(gates_inputs[:, :, :self.layer_hidden_size])
  96.             forget_gate = self.sigmoid(gates_inputs[:, :, self.layer_hidden_size:2 * self.layer_hidden_size])
  97.             cell_gate = self.tanh(gates_inputs[:, :, self.layer_hidden_size * 2:self.layer_hidden_size * 3])
  98.             output_gate = self.sigmoid(gates_inputs[:, :, self.layer_hidden_size * 3:self.layer_hidden_size * 4])
  99.  
  100.             c_t = c_t * forget_gate # forget information
  101.            
  102.             c_t += input_gate * cell_gate # add new information
  103.             h_t = output_gate * self.tanh(c_t) # output information
  104.  
  105.             out[:, t, :] = h_t.squeeze(1)
  106.  
  107.         return out
  108.  
Advertisement
Add Comment
Please, Sign In to add comment