Advertisement
Guest User

Untitled

a guest
May 2nd, 2025
22
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.46 KB | Source Code | 0 0
import os
from typing import Optional

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import TensorDataset, random_split, DataLoader
  5.  
  6. num_workers = os.cpu_count()
  7.  
  8. class DataModule(LightningDataModule):
  9.     def __init__(
  10.         self,
  11.         X_train: torch.Tensor,
  12.         X_test: torch.Tensor,
  13.         y_train: torch.Tensor,
  14.         y_test: torch.Tensor,
  15.         batch_size: int = 64
  16.     ):
  17.         super().__init__()
  18.         self.X_train = X_train
  19.         self.y_train = y_train
  20.        
  21.         self.X_test = X_test
  22.         self.y_test = y_test
  23.        
  24.         self.train_set = None
  25.         self.test_set = TensorDataset(X_test, y_test)
  26.         self.val_set = None
  27.        
  28.         self.batch_size = batch_size
  29.        
  30.     def setup(self, stage: str = None) -> None:
  31.         if stage == 'fit' or stage is None:
  32.             dataset = TensorDataset(self.X_train, self.y_train)
  33.            
  34.             train_set, val_set = random_split(dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
  35.             self.train_set = train_set
  36.             self.val_set = val_set
  37.  
  38.     def train_dataloader(self):
  39.         return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
  40.    
  41.     def val_dataloader(self):
  42.         return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)
  43.    
  44.     def test_dataloader(self):
  45.         return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement