Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Preprocessing pipeline applied to each image: convert the array to a PIL
# image, centre-crop to 224x224, convert to a CxHxW float tensor, then
# normalise each channel with the given per-channel means and stds.
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    # transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

# Convert images to 3-channel images by repeating the single channel.
# NOTE(review): assumes x_train is (N, 256, 256) single-channel data - confirm.
x_train = x_train[:batch_size]

# Repeating along a new trailing axis yields a channels-last (B, H, W, 3)
# array. That is the layout ToPILImage expects for NumPy input, so the array
# must NOT be reshaped to (B, 3, H, W): a reshape reinterprets the buffer
# instead of moving the axis, which scrambles the pixels and hands ToPILImage
# a "256-channel" image. ToTensor produces the channels-first layout later.
rgb_batch = np.repeat(x_train[..., np.newaxis], 3, -1)
# NOTE(review): ToPILImage only accepts uint8 for 3-channel arrays; cast
# rgb_batch beforehand if x_train holds another dtype - confirm value range.

# Iterate over the actual batch rather than range(batch_size) so a dataset
# with fewer than batch_size samples does not raise an IndexError.
train_data = torch.stack([train_transforms(image) for image in rgb_batch])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement