from collections import OrderedDict

import torch
from torch import nn
import pytorch_lightning as pl

class FCNNRegressor(pl.LightningModule):
    """Fully connected regressor: two hidden blocks (Linear -> BatchNorm -> ReLU -> Dropout)
    followed by a single-unit linear output head."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim1: int,
        hidden_dim2: int,
        dropout_prob: float,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
    ):
        super().__init__()

        self.model = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(input_dim, hidden_dim1)),
            ('bn1', nn.BatchNorm1d(hidden_dim1)),
            ('relu1', nn.ReLU()),
            ('dropout1', nn.Dropout(p=dropout_prob)),
            ('fc2', nn.Linear(hidden_dim1, hidden_dim2)),
            ('bn2', nn.BatchNorm1d(hidden_dim2)),
            ('relu2', nn.ReLU()),
            ('dropout2', nn.Dropout(p=dropout_prob)),
            ('out', nn.Linear(hidden_dim2, 1)),
        ]))

        self.lr = lr
        self.weight_decay = weight_decay
        self.criterion = nn.MSELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        x, y = batch
        preds = self(x)  # shape (batch_size, 1); targets should have the same shape
        loss = self.criterion(preds, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        x, y = batch
        preds = self(x)
        loss = self.criterion(preds, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self) -> dict:
        # Split parameters so that weight decay is applied only to weight matrices;
        # biases and batch-norm parameters (named 'bn*') are excluded from decay.
        decay = []
        no_decay = []

        for name, param in self.named_parameters():
            if param.requires_grad:
                if 'bias' in name or 'bn' in name:
                    no_decay.append(param)
                else:
                    decay.append(param)

        optimizer = torch.optim.Adam([
            {'params': decay, 'weight_decay': self.weight_decay},
            {'params': no_decay, 'weight_decay': 0.0},
        ], lr=self.lr)

        # Halve the learning rate whenever the monitored validation loss plateaus.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=2,
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'frequency': 1,
            },
        }
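
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original paste). The synthetic data, layer
# sizes, and Trainer arguments below are illustrative assumptions, included
# only to show how the module is wired together; swap in your own dataset
# and settings.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    # Synthetic data purely for demonstration: 16 features, scalar targets
    # shaped (N, 1) to match the network's single output unit.
    x = torch.randn(1024, 16)
    y = torch.randn(1024, 1)
    train_loader = DataLoader(TensorDataset(x[:800], y[:800]), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(x[800:], y[800:]), batch_size=32)

    model = FCNNRegressor(input_dim=16, hidden_dim1=64, hidden_dim2=32, dropout_prob=0.2)
    trainer = pl.Trainer(max_epochs=20, accelerator='auto')
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)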