Advertisement
Guest User

load-data

a guest
Jul 28th, 2018
965
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.04 KB | None | 0 0
  1. # Loads the images for use with the CNN.
  2. def load_images(image_size=32, batch_size=64, root="../images"):
  3.     transform = transforms.Compose([
  4.         transforms.Resize(32),
  5.         transforms.ToTensor(),
  6.         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  7.  
  8.     train_set = datasets.ImageFolder(root=root, train=True, transform=transform)
  9.     train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
  10.  
  11.     return train_loader
  12.  
  13.  
  14. # Defining variables for use with the CNN.
  15. classes = ('daisy', 'dandelion', 'rose', 'sunflower', 'tulip')
  16. train_loader_data = load_images()
  17.  
  18. # Training samples.
  19. n_training_samples = 3394
  20. train_sampler = SubsetRandomSampler(np.arange(n_training_samples, dtype=np.int64))
  21.  
  22. # Validation samples.
  23. n_val_samples = 424
  24. val_sampler = SubsetRandomSampler(np.arange(n_training_samples, n_training_samples + n_val_samples, dtype=np.int64))
  25.  
  26. # Test samples.
  27. n_test_samples = 424
  28. test_sampler = SubsetRandomSampler(np.arange(n_test_samples, dtype=np.int64))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement