Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import torch
- from pytorch_lightning import LightningDataModule
- from torch.utils.data import TensorDataset, random_split, DataLoader
# Default worker count for DataLoaders. os.cpu_count() may return None when
# the CPU count cannot be determined, so fall back to a single worker.
num_workers = os.cpu_count() or 1
class DataModule(LightningDataModule):
    """Lightning data module serving pre-built train/test tensors.

    The training tensors are split into training and validation subsets in
    :meth:`setup`. The test tensors are wrapped once and served unchanged.

    Args:
        X_train: Training inputs, paired row-wise (dim 0) with ``y_train``.
        X_test: Test inputs, paired row-wise (dim 0) with ``y_test``.
        y_train: Training targets.
        y_test: Test targets.
        batch_size: Batch size used by every dataloader.
        val_split: Fraction of the training tensors held out for validation.
            Must lie strictly between 0 and 1.
        seed: Seed for the generator driving the train/validation split, so
            repeated ``setup`` calls produce the same split.
        num_workers: Worker processes passed to each ``DataLoader``. ``0``
            loads in the main process, matching the previous behavior.

    Raises:
        ValueError: If ``val_split`` is not strictly between 0 and 1.
    """

    def __init__(
        self,
        X_train: torch.Tensor,
        X_test: torch.Tensor,
        y_train: torch.Tensor,
        y_test: torch.Tensor,
        batch_size: int = 64,
        val_split: float = 0.2,
        seed: int = 42,
        num_workers: int = 0,
    ):
        super().__init__()
        if not 0.0 < val_split < 1.0:
            raise ValueError(f"val_split must be in (0, 1), got {val_split!r}")
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        # train/val subsets are created lazily in setup(); the test set needs
        # no splitting and is built immediately.
        self.train_set = None
        self.test_set = TensorDataset(X_test, y_test)
        self.val_set = None
        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed
        self.num_workers = num_workers

    def setup(self, stage: "str | None" = None) -> None:
        """Split the training tensors into train/validation subsets.

        Args:
            stage: Lightning stage name. The split is built for ``"fit"``,
                ``"validate"``, or ``None`` (all stages). ``"validate"`` must be
                included so that ``trainer.validate()`` gets a real val set.
        """
        if stage in (None, "fit", "validate"):
            dataset = TensorDataset(self.X_train, self.y_train)
            # A freshly seeded generator makes the split reproducible across
            # calls and independent of the global RNG state.
            generator = torch.Generator().manual_seed(self.seed)
            self.train_set, self.val_set = random_split(
                dataset,
                [1.0 - self.val_split, self.val_split],
                generator=generator,
            )

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return an ordered (unshuffled) loader over the validation subset."""
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Return an ordered (unshuffled) loader over the test tensors."""
        return DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement