Advertisement
Guest User

Untitled

a guest
Feb 18th, 2020
72
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.96 KB | None | 0 0
  1. class TransformsCoCo:
  2.     '''
  3.    CoCo dataset, for use with 128x128 full image encoder.
  4.    '''
  5.     def __init__(self):
  6.         # image augmentation functions
  7.         self.flip_lr = transforms.RandomHorizontalFlip(p=0.5)
  8.         self.multi_crop = transforms.RandomChoice([
  9.             transforms.RandomResizedCrop(128, scale=(0.3, 1.0), ratio=(0.7, 1.4),
  10.                                          interpolation=1),
  11.             transforms.RandomResizedCrop(128, scale=(0.3, 1.0), ratio=(0.7, 1.4),
  12.                                          interpolation=2),
  13.             transforms.RandomResizedCrop(128, scale=(0.3, 1.0), ratio=(0.7, 1.4),
  14.                                          interpolation=3)
  15.         ])
  16.         self.col_jitter = transforms.RandomApply([
  17.             transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8)
  18.         self.rnd_gray = transforms.RandomGrayscale(p=0.25)
  19.         # augmentations for generating "augmented" and "clean" data
  20.         self.train_transform = transforms.Compose([
  21.             self.multi_crop,  # + rand crop and interp
  22.             self.col_jitter,  # + color jitter
  23.             self.rnd_gray,    # + color jitter
  24.             transforms.ToTensor(),
  25.             transforms.Normalize(mean=[0.485, 0.456, 0.406],
  26.                                  std=[0.229, 0.224, 0.225])
  27.         ])
  28.         self.test_transform = transforms.Compose([
  29.             transforms.Resize(146, interpolation=3),
  30.             transforms.CenterCrop(128),
  31.             transforms.ToTensor(),
  32.             transforms.Normalize(mean=[0.485, 0.456, 0.406],
  33.                                  std=[0.229, 0.224, 0.225])
  34.         ])
  35.  
  36.     def __call__(self, inp):
  37.         inp = self.flip_lr(inp)
  38.         # gather augmented copies of the original (flippy) input
  39.         out1 = self.train_transform(inp)
  40.         out2 = self.train_transform(inp)
  41.         # get a clean copy of the original (flippy) input
  42.         orig = self.test_transform(inp)
  43.         return orig, out1, out2
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement