Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def load_images(image_size=32, batch_size=64, root="../images", num_workers=2):
    """Build a shuffled DataLoader over an ImageFolder-style image directory.

    Args:
        image_size: Side length, in pixels, of the square images fed to the CNN.
        batch_size: Number of images per batch.
        root: Directory containing one sub-folder per class (ImageFolder layout).
        num_workers: Number of worker processes used by the DataLoader.

    Returns:
        A ``torch.utils.data.DataLoader`` over every image under ``root``,
        reshuffled each epoch. The underlying dataset is ``loader.dataset``.
    """
    transform = transforms.Compose([
        # Resize the shorter edge, then crop the centre, so every image ends up
        # exactly image_size x image_size. Resize alone keeps the aspect ratio,
        # and images of different shapes cannot be stacked into one batch.
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps each of the three channels from [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # ImageFolder takes no ``train`` split flag; the directory itself is the dataset.
    train_set = datasets.ImageFolder(root=root, transform=transform)
    return torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
# Defining variables for use with the CNN.
# NOTE(review): ImageFolder numbers classes by sorted sub-folder name, so this
# tuple must match the folder names under the image root in that order — confirm.
classes = ('daisy', 'dandelion', 'rose', 'sunflower', 'tulip')
train_loader_data = load_images()

# Split sizes (training / validation / test).
n_training_samples = 3394
n_val_samples = 424
n_test_samples = 424

_n_dataset = len(train_loader_data.dataset)
_n_needed = n_training_samples + n_val_samples + n_test_samples
if _n_needed > _n_dataset:
    # Fail early with a clear message instead of an IndexError mid-epoch.
    raise ValueError(
        f"split sizes need {_n_needed} samples but the dataset has only {_n_dataset}"
    )

# ImageFolder lists its samples class by class, so contiguous index ranges
# would give class-skewed splits (validation/test drawn almost entirely from
# the last classes). Shuffle all indices once with a fixed seed so the split is
# reproducible, then carve out three disjoint slices.
_split_order = np.random.default_rng(0).permutation(_n_dataset).astype(np.int64)
_val_end = n_training_samples + n_val_samples

# NOTE(review): these samplers index into train_loader_data.dataset; a
# DataLoader that consumes one must not also pass shuffle=True — confirm usage.
# Training samples.
train_sampler = SubsetRandomSampler(_split_order[:n_training_samples])
# Validation samples.
val_sampler = SubsetRandomSampler(_split_order[n_training_samples:_val_end])
# Test samples (disjoint from the training and validation slices).
test_sampler = SubsetRandomSampler(_split_order[_val_end:_val_end + n_test_samples])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement