Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from torch.quantization import quantize_dynamic
- from torchao.quantization import (
- quantize_,
- Int8DynamicActivationInt8WeightConfig
- )
- import torch
- from tqdm import tqdm
- import torch.nn as nn
- from torch.optim import Adam
class MyLSTM(nn.Module):
    """Stack of ``MyLSTMLayer`` blocks followed by a per-time-step linear head.

    The attribute names ``lstm`` (a ``ModuleList`` of layers) and ``fc_last``
    define the state-dict keys (``lstm.<k>.input_weights.*``,
    ``lstm.<k>.hidden_weights.*``, ``fc_last.*``); code that copies weights
    into this model by key depends on them, so keep them stable.
    """

    def __init__(
        self,
        input_size,
        output_size,
        hidden_size,
        n_layers,
        dropout: float = 0.1,
        use_layer_norm: bool = False,
    ):
        """
        Args:
            input_size (int): Number of input features per time step.
            output_size (int): Number of output features per time step.
            hidden_size (int): Number of features in every layer's hidden state.
            n_layers (int): Number of stacked recurrent layers; must be >= 1.
            dropout (float): Dropout probability applied after every layer.
            use_layer_norm (bool): Accepted but currently unused.

        Raises:
            ValueError: If ``n_layers`` is smaller than 1.
        """
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if n_layers < 1:
            raise ValueError(f"Expected at least 1 layer for LSTM, but found {n_layers}")
        self.dropout_layer = nn.Dropout(p=dropout)
        # The first layer consumes the raw input; deeper layers consume the
        # previous layer's hidden states.
        self.lstm = nn.ModuleList(
            MyLSTMLayer(
                layer_input_size=input_size if idx == 0 else hidden_size,
                layer_hidden_size=hidden_size,
            )
            for idx in range(n_layers)
        )
        self.fc_last = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run all layers, then project every time step to ``output_size``.

        Dropout is applied after each layer, including the last one, before
        the linear head.

        Args:
            x (torch.Tensor): Batch-first input sequence; presumably
                (batch, seq_len, input_size) as ``MyLSTMLayer`` expects.

        Returns:
            torch.Tensor: ``fc_last`` applied to the last layer's outputs.
        """
        for lstm_layer in self.lstm:
            x = self.dropout_layer(lstm_layer(x))
        return self.fc_last(x)
class MyLSTMLayer(nn.Module):
    """Single LSTM layer, unrolled step by step over the sequence dimension.

    The recurrence is expressed with two ``nn.Linear`` projections
    (``input_weights``: input -> 4*hidden, ``hidden_weights``: hidden ->
    4*hidden). The fused gate columns are ordered input | forget | cell |
    output.
    """

    def __init__(
        self,
        layer_input_size,
        layer_hidden_size,
    ):
        """
        Args:
            layer_input_size (int): Number of features of each input step.
            layer_hidden_size (int): Number of features of the hidden/cell state.
        """
        super().__init__()
        self.layer_input_size = layer_input_size
        self.layer_hidden_size = layer_hidden_size
        # Each projection produces all four gates at once (4 * hidden features).
        self.input_weights = nn.Linear(in_features=layer_input_size, out_features=layer_hidden_size * 4, bias=True)
        self.hidden_weights = nn.Linear(in_features=layer_hidden_size, out_features=layer_hidden_size * 4, bias=True)
        self.tanh = torch.tanh
        self.sigmoid = torch.sigmoid

    def forward(
        self,
        x,
    ):
        """Run the layer over a batch-first sequence starting from a zero state.

        Args:
            x (torch.Tensor): Input of shape (batch, seq_len, layer_input_size).

        Returns:
            torch.Tensor: Hidden state at every step, shape
            (batch, seq_len, layer_hidden_size), with the dtype and device of ``x``.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        hidden = self.layer_hidden_size
        # Allocate the zero state directly with x's dtype/device: avoids a
        # per-call host allocation + transfer and works for non-float32 inputs.
        h_t = x.new_zeros((batch_size, hidden))
        c_t = x.new_zeros((batch_size, hidden))
        # The input projection does not depend on the recurrence, so it is
        # computed for all time steps in a single call.
        projected_inputs = self.input_weights(x)
        out = x.new_empty((batch_size, seq_len, hidden))
        for t in range(seq_len):
            gates = projected_inputs[:, t, :] + self.hidden_weights(h_t)
            input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, dim=-1)
            input_gate = self.sigmoid(input_gate)
            forget_gate = self.sigmoid(forget_gate)
            cell_gate = self.tanh(cell_gate)
            output_gate = self.sigmoid(output_gate)
            # Forget part of the old cell state, then add the new candidate.
            c_t = c_t * forget_gate + input_gate * cell_gate
            h_t = output_gate * self.tanh(c_t)
            out[:, t, :] = h_t
        return out
class LSTM(nn.Module):
    """Multi-layer ``nn.LSTM`` followed by dropout and a per-time-step linear head.

    Every time step of the LSTM output is projected to ``output_size``
    features, so the output keeps the sequence dimension.
    """

    def __init__(
        self,
        input_size,
        output_size,
        hidden_size,
        n_layers,
        dropout,
        use_layer_norm: bool = False,
        batch_first=True,
    ):
        """
        Args:
            input_size (int): Number of input features per time step.
            output_size (int): Number of output features per time step.
            hidden_size (int): Number of features in the hidden state.
            n_layers (int): Number of stacked recurrent layers.
            dropout (float): Dropout probability between LSTM layers and on
                the output of the last LSTM layer.
            use_layer_norm (bool): Accepted but currently unused.
            batch_first (bool): If True, tensors are (batch, seq, feature);
                otherwise (seq, batch, feature).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = n_layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=batch_first,
            # Honour the `dropout` argument (previously hard-coded to 0.1).
            # nn.LSTM only drops out between stacked layers and warns when a
            # non-zero value is given for a single layer.
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.dropout_layer = nn.Dropout(p=dropout)
        self.fc_last = nn.Linear(hidden_size, output_size)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init for the head's weight and zero init for its bias."""
        nn.init.xavier_uniform_(self.fc_last.weight)
        nn.init.constant_(self.fc_last.bias, 0)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input of shape (batch, seq, input_size) when
                ``batch_first`` is True, else (seq, batch, input_size).

        Returns:
            torch.Tensor: Same leading dimensions as ``x`` with ``output_size``
            features in the last dimension.
        """
        # Omitting the initial (h0, c0) makes nn.LSTM start from zeros whose
        # batch size, dtype and device match `x` for either batch_first setting.
        out, _ = self.lstm(x)
        out = self.dropout_layer(out)  # dropout after lstm
        return self.fc_last(out)
n_iters = 1_000
batch_size = 128
seq_len = 500
# `torch.cuda.is_available` must be *called*: the bare function object is
# always truthy, which selected 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Set to a positive number (e.g. 10) to print the running mean training loss
# every `log_every` iterations; 0 disables logging.
log_every = 0
model_config = {
    "input_size": 1,
    "output_size": 1,
    "hidden_size": 256,
    "n_layers": 2,
    "dropout": 0.1,
}
# Move the model to its final device before building the optimizer over its parameters.
model = LSTM(**model_config).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
model.train()
total_loss = 0
for idx in tqdm(range(n_iters), total=n_iters):
    # Task: the model sees x plus Gaussian noise (std 0.1) and must predict sin(x).
    x = (torch.rand((batch_size, seq_len, 1), device=device) * 2 - 1) * torch.pi  # random angle from -pi to pi
    eps = torch.randn((batch_size, seq_len, 1), device=device) / 10
    y = torch.sin(x)
    optimizer.zero_grad()
    output = model(x + eps)
    loss = (output - y).pow(2).mean()
    loss.backward()
    optimizer.step()
    total_loss += loss.detach()
    if log_every and (idx + 1) % log_every == 0:
        print("Loss:", (total_loss / log_every).item())
        total_loss = 0
# Evaluate the float model on a fresh batch of noise-free angles.
model = model.eval().to(device)
with torch.no_grad():
    test_input = (torch.rand((1000, 500, 1), device=device) * 2 - 1) * torch.pi
    test_target = torch.sin(test_input)
    test_pred = model(test_input)
    baseline_err = test_pred - test_target
print("Baseline solution (no quantization):")
print(
    f"MSE: {baseline_err.pow(2).mean().item():.5f}",
    f"MAE: {baseline_err.abs().mean().item():.5f}",
)
# Dynamic int8 quantization of the Linear and LSTM modules with the built-in
# torch API. NOTE: `Module.to` is in-place, so `model.to('cpu')` also leaves
# the float `model` itself on CPU from here on.
q_model = quantize_dynamic(
    model.to('cpu'),
    {nn.Linear, nn.LSTM},
    torch.qint8,
)
q_model.eval()
# The quantized model runs on CPU, so evaluate against CPU copies made once.
test_input_cpu = test_input.cpu()
test_target_cpu = test_target.cpu()
# Inference only: disable autograd, consistent with the other evaluations.
with torch.no_grad():
    q_pred = q_model(test_input_cpu)
    q_err = q_pred - test_target_cpu
print("Quantization of the trained model using torch.quantization:")
print(
    "MSE: {mse:.5f}".format(mse=q_err.pow(2).mean().item()),
    "MAE: {mae:.5f}".format(mae=torch.abs(q_err).mean().item()),
)
def _nn_lstm_key(my_lstm_key):
    """Translate a ``MyLSTM`` parameter name into the matching ``LSTM`` state-dict key.

    ``lstm.<k>.input_weights.<p>``  -> ``lstm.<p>_ih_l<k>``
    ``lstm.<k>.hidden_weights.<p>`` -> ``lstm.<p>_hh_l<k>``
    Any other name (e.g. ``fc_last.weight``) is returned unchanged.
    """
    if "weights" not in my_lstm_key:
        return my_lstm_key
    parts = my_lstm_key.split(".")
    source = "i" if parts[2].startswith("input") else "h"
    return f"lstm.{parts[-1]}_{source}h_l{parts[1]}"


# Copy the trained nn.LSTM-based weights into the nn.Linear-based MyLSTM,
# then quantize that copy with torchao.
q_model_torchao = MyLSTM(**model_config)
model_state_dict = model.state_dict()
my_lstm_model_state_dict = {
    param_name: model_state_dict[_nn_lstm_key(param_name)]
    for param_name, _ in q_model_torchao.named_parameters()
}
q_model_torchao.load_state_dict(my_lstm_model_state_dict)
quantize_(
    q_model_torchao,
    Int8DynamicActivationInt8WeightConfig(),
    # device='cpu', uncomment for the comparison on the same device
)
q_model_torchao = q_model_torchao.to(device).eval()
# if device of `q_model_torchao` is cpu, move test_input and test_target to cpu
with torch.no_grad():
    q_pred_torchao = q_model_torchao(test_input)
    torchao_err = q_pred_torchao - test_target
print("Quantization of the trained model using torchao:")
print(
    f"MSE: {torchao_err.pow(2).mean().item():.5f}",
    f"MAE: {torchao_err.abs().mean().item():.5f}",
)
Advertisement