Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class TransformsCoCo:
    '''
    Augmentation pipeline for the CoCo dataset, for use with a 128x128
    full image encoder.

    Calling an instance on an image applies one random horizontal flip and
    then returns three views of the flipped image: a "clean" view from the
    test pipeline and two views drawn separately from the train pipeline.
    '''

    def __init__(self):
        # Random left/right flip; applied once per input in __call__ so that
        # all three returned views share the same orientation.
        self.flip_lr = transforms.RandomHorizontalFlip(p=0.5)

        # One random-resized-crop per interpolation code (1, 2, 3);
        # RandomChoice picks one of them each time it is applied.
        crop_variants = [
            transforms.RandomResizedCrop(
                128, scale=(0.3, 1.0), ratio=(0.7, 1.4),
                interpolation=interp_code)
            for interp_code in (1, 2, 3)
        ]
        self.multi_crop = transforms.RandomChoice(crop_variants)

        # Color jitter is applied with probability 0.8, grayscale with 0.25.
        jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        self.col_jitter = transforms.RandomApply([jitter], p=0.8)
        self.rnd_gray = transforms.RandomGrayscale(p=0.25)

        def tensor_tail():
            # Fresh ToTensor + Normalize pair that ends each pipeline.
            return [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ]

        # "augmented" pipeline: random crop/interp, color jitter, grayscale
        train_steps = [self.multi_crop, self.col_jitter, self.rnd_gray]
        self.train_transform = transforms.Compose(train_steps + tensor_tail())

        # "clean" pipeline: fixed resize followed by a center crop
        test_steps = [
            transforms.Resize(146, interpolation=3),
            transforms.CenterCrop(128),
        ]
        self.test_transform = transforms.Compose(test_steps + tensor_tail())

    def __call__(self, inp):
        '''
        Return (clean, augmented_1, augmented_2) views of the input.
        '''
        flipped = self.flip_lr(inp)
        # The two augmented views are produced first and the clean view
        # last, so the sequence of transform calls stays fixed.
        aug_first = self.train_transform(flipped)
        aug_second = self.train_transform(flipped)
        clean = self.test_transform(flipped)
        return clean, aug_first, aug_second
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement