from collections import OrderedDict
from typing import Dict, Tuple

import torch
from torch import nn
import pytorch_lightning as pl
class FCNNRegressor(pl.LightningModule):
    """Two-hidden-layer fully connected regressor with batch norm and dropout."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim1: int,
        hidden_dim2: int,
        dropout_prob: float,
        lr: float = 1e-3,
        weight_decay: float = 1e-4
    ):
        super().__init__()
        self.model = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(input_dim, hidden_dim1)),
            ('bn1', nn.BatchNorm1d(hidden_dim1)),
            ('relu1', nn.ReLU()),
            ('dropout1', nn.Dropout(p=dropout_prob)),
            ('fc2', nn.Linear(hidden_dim1, hidden_dim2)),
            ('bn2', nn.BatchNorm1d(hidden_dim2)),
            ('relu2', nn.ReLU()),
            ('dropout2', nn.Dropout(p=dropout_prob)),
            ('out', nn.Linear(hidden_dim2, 1))  # single-output regression head
        ]))
        self.lr = lr
        self.weight_decay = weight_decay
        self.criterion = nn.MSELoss()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        # Targets should be shaped (N, 1) to match the model output;
        # otherwise MSELoss broadcasts and silently computes the wrong loss.
        x, y = batch
        preds = self(x)
        loss = self.criterion(preds, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch
        preds = self(x)
        loss = self.criterion(preds, y)
        self.log('val_loss', loss)  # monitored by the LR scheduler below
        return loss
    def configure_optimizers(self) -> Dict:
        # Exclude biases and batch-norm parameters from weight decay;
        # decaying them tends to hurt rather than regularize.
        decay = []
        no_decay = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                if 'bias' in name or 'bn' in name:
                    no_decay.append(param)
                else:
                    decay.append(param)
        optimizer = torch.optim.Adam([
            {'params': decay, 'weight_decay': self.weight_decay},
            {'params': no_decay, 'weight_decay': 0.0}
        ], lr=self.lr)
        # Halve the learning rate after val_loss plateaus for 2 epochs.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=2
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'frequency': 1
            }
        }
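

# --- Usage sketch (illustrative, not part of the original paste) ---
# A minimal end-to-end run showing how the module above wires into a
# pl.Trainer. The synthetic data, dimensions, and hyperparameters here are
# hypothetical placeholders; substitute your own dataset and dataloaders.
# Targets are shaped (N, 1) to match the single-unit output head, and the
# batch size is kept > 1 so BatchNorm1d has statistics to work with.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    # Synthetic regression data: 256 samples, 16 features.
    x = torch.randn(256, 16)
    y = torch.randn(256, 1)
    dataset = TensorDataset(x, y)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=32)

    model = FCNNRegressor(
        input_dim=16,
        hidden_dim1=64,
        hidden_dim2=32,
        dropout_prob=0.2
    )
    # Validation must run each epoch so 'val_loss' exists for
    # ReduceLROnPlateau to monitor.
    trainer = pl.Trainer(max_epochs=5, logger=False, enable_checkpointing=False)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)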