Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn as nn
class MyLSTM(nn.Module):
    """Sequence model: dense input stack -> stacked LSTM layers -> dense output stack -> linear head.

    The input stack keeps the feature width at ``input_size``; the recurrent
    layers map it to ``hidden_size``; the output stack keeps ``hidden_size``;
    the final linear layer projects to ``output_size``.

    Args:
        input_size: Feature width of the incoming tensor's last dimension.
        output_size: Feature width produced by the final linear layer.
        hidden_size: Hidden width of every recurrent layer and of the output stack.
        fc_depth_input: Number of (Linear, ReLU, Dropout) triples before the recurrent layers.
        fc_depth_output: Number of (Linear, ReLU, Dropout) triples after the recurrent layers.
        n_layers: Number of stacked ``MyLSTMLayer`` modules; must be at least 1.
        dropout: Dropout probability shared by every dropout module in the model.
        use_layer_norm: Accepted for interface compatibility.
            NOTE(review): this flag is never read anywhere in this class.
    """

    def __init__(
        self,
        input_size,
        output_size,
        hidden_size,
        fc_depth_input,
        fc_depth_output,
        n_layers,
        dropout: float = 0.1,
        use_layer_norm: bool = False,
    ):
        super().__init__()
        assert n_layers >= 1, f"Expected at least 1 layer for LSTM, but found {n_layers}"

        # Dropout applied to the output of the recurrent layers in forward().
        self.dropout_layer = nn.Dropout(p=dropout)

        # Width-preserving dense pre-processing on the raw features.
        self.fc_input = nn.ModuleList()
        for _ in range(fc_depth_input):
            self.fc_input.extend(
                [nn.Linear(input_size, input_size), nn.ReLU(), nn.Dropout(p=dropout)]
            )

        # Only the first recurrent layer sees input_size features; the rest
        # consume the hidden state of the layer below.
        recurrent_in_widths = [input_size] + [hidden_size] * (n_layers - 1)
        self.lstm = nn.ModuleList(
            [
                MyLSTMLayer(layer_input_size=in_width, layer_hidden_size=hidden_size)
                for in_width in recurrent_in_widths
            ]
        )

        # Width-preserving dense post-processing on the hidden features.
        self.fc_output = nn.ModuleList()
        for _ in range(fc_depth_output):
            self.fc_output.extend(
                [nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(p=dropout)]
            )

        # Final projection to the requested output width.
        self.fc_last = nn.Linear(hidden_size, output_size)

    def forward(
        self,
        x,
    ):
        """Run ``x`` through every stage of the model and return the projected result.

        ``x`` is passed to the recurrent layers unchanged in rank, and those
        layers index dimensions 0 and 1 of their input, so it is presumably
        shaped (batch, seq_len, input_size) -- confirm against the caller.
        """
        features = x
        for pre_module in self.fc_input:
            features = pre_module(features)

        # NOTE(review): the pasted original lost its indentation; dropout is
        # assumed to follow every recurrent layer (inside this loop) rather
        # than being applied once after the whole stack -- confirm.
        for recurrent in self.lstm:
            features = self.dropout_layer(recurrent(features))

        for post_module in self.fc_output:
            features = post_module(features)

        return self.fc_last(features)
class MyLSTMLayer(nn.Module):
    """A single LSTM layer written out step by step over a batch-first sequence.

    Two linear projections (one on the input, one on the previous hidden
    state) each produce ``4 * layer_hidden_size`` features. Their sum is
    split, in order, into the input, forget, cell and output gate
    pre-activations.

    Args:
        layer_input_size: Feature width of the incoming sequence.
        layer_hidden_size: Width of the hidden and cell state, and of the output.
    """

    def __init__(
        self,
        layer_input_size,
        layer_hidden_size,
    ):
        super().__init__()
        self.layer_input_size = layer_input_size
        self.layer_hidden_size = layer_hidden_size
        # Each projection emits all four gate pre-activations side by side.
        self.input_weights = nn.Linear(in_features=layer_input_size, out_features=layer_hidden_size * 4, bias=True)
        self.hidden_weights = nn.Linear(in_features=layer_hidden_size, out_features=layer_hidden_size * 4, bias=True)
        self.tanh = torch.tanh
        self.sigmoid = torch.sigmoid

    def forward(
        self,
        x,
    ):
        """Return the hidden state at every time step.

        Args:
            x: Tensor indexed as (batch, seq_len, layer_input_size).

        Returns:
            Tensor of shape (batch, seq_len, layer_hidden_size). The hidden and
            cell state start at zero for every call and are not returned.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        hidden_size = self.layer_hidden_size

        # The input projection does not depend on the recurrence, so it is
        # computed for the whole sequence in one call.
        output_inputs = self.input_weights(x)

        # Allocate state and output from the projected input so they inherit
        # its dtype and device. Default-dtype buffers break (or silently
        # downcast) as soon as the module runs in double or half precision.
        state_shape = (batch_size, 1, hidden_size)
        c_t = output_inputs.new_zeros(state_shape)
        h_t = output_inputs.new_zeros(state_shape)
        out = output_inputs.new_empty((batch_size, seq_len, hidden_size))

        for t in range(seq_len):
            # Keep the time axis (size 1) so shapes line up with h_t.
            gates_inputs = output_inputs[:, t:t + 1, :] + self.hidden_weights(h_t)
            # Gate order along the last axis: input, forget, cell, output.
            input_pre, forget_pre, cell_pre, output_pre = gates_inputs.chunk(4, dim=-1)
            input_gate = self.sigmoid(input_pre)
            forget_gate = self.sigmoid(forget_pre)
            cell_gate = self.tanh(cell_pre)
            output_gate = self.sigmoid(output_pre)

            # Forget part of the old cell state, then add the new candidate.
            c_t = forget_gate * c_t + input_gate * cell_gate
            h_t = output_gate * self.tanh(c_t)
            out[:, t, :] = h_t.squeeze(1)
        return out
Advertisement
Add Comment
Please, Sign In to add comment