Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
def data_loader(data_path, batch_size, num_workers):
    """Build the training and validation DataLoaders for an ImageFolder tree.

    Expects ``data_path`` to contain ``train`` and ``val`` subdirectories,
    each laid out the way ``ImageFolder`` requires (one subdirectory per
    class).

    Args:
        data_path: Root directory holding the ``train`` and ``val`` splits.
        batch_size: Number of samples per batch, used for both loaders.
        num_workers: Number of loader worker processes, used for both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader is
        shuffled; the validation loader iterates in a fixed order.
    """
    # NOTE(review): `datasets`, `train_transform` and `eval_transform` are
    # module-level names defined outside this function (presumably
    # torchvision.datasets and two torchvision transform pipelines) — confirm.

    def _make_loader(split, transform, shuffle):
        # One DataLoader over <data_path>/<split>; the two splits differ only
        # in their transform and whether samples are shuffled.
        return torch.utils.data.DataLoader(
            datasets.ImageFolder(os.path.join(data_path, split), transform),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    train_loader = _make_loader('train', train_transform, shuffle=True)
    # Evaluation order does not affect aggregate metrics, so keep it fixed:
    # shuffling validation data only makes per-sample results irreproducible.
    val_loader = _make_loader('val', eval_transform, shuffle=False)
    return train_loader, val_loader
Add Comment
Please sign in to add a comment.