Bunich

AML ch2 DS class

May 19th, 2021
38
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  class RefugeDataset(Dataset):
      """Dataset that eagerly loads one split of images and ground truth into memory.

      Files are read from ``<root_dir>/<split>``:

      * ``index.json`` -- a mapping keyed by the strings ``"0"`` .. ``"N-1"``;
        every entry provides ``'ImgName'`` and, for non-test splits, ``'Label'``.
      * ``images/<ImgName>`` -- the input images.
      * ``gts/<ImgName without extension>.bmp`` -- ground-truth masks
        (only read when ``split != 'test'``).

      Images are converted to RGB, turned into tensors, resized to
      ``output_size`` (bilinear) and normalized with fixed per-channel
      mean/std.  Masks are inverted (``255 - value``) and thresholded into two
      binary channels (``>= 127`` and ``>= 250``), each resized with
      nearest-neighbour interpolation and stacked along dim 0.
      NOTE(review): the ``od``/``oc`` names suggest optic disc / optic cup --
      confirm against the dataset documentation.

      When ``augment`` is true and ``split == 'train'``, every sample is
      immediately followed by five fixed variants (see
      ``_augmented_variants``), so the dataset is six times longer and
      ``images``, ``segs`` and ``labels`` stay index-aligned.
      """

      def __init__(self, root_dir, split='train', output_size=(256, 256), augment=False):
          """Load the whole split into memory.

          Args:
              root_dir: Directory containing one sub-directory per split.
              split: Name of the split sub-directory.  ``'test'`` skips
                  labels and masks; ``'train'`` is the only split that can
                  be augmented.
              output_size: Target size passed to the resize transforms.
              augment: If true (and ``split == 'train'``), append the
                  augmented variants after each original sample.

          Raises:
              OSError: If the index, an image or a mask cannot be opened.
              KeyError: If the index lacks an expected key.
          """
          super().__init__()

          # Define attributes
          self.output_size = output_size
          self.root_dir = root_dir
          self.split = split
          self.labels = []

          # Image preprocessing pipeline
          trans_img = transforms.Compose([
              transforms.ToTensor(),
              transforms.Resize(self.output_size, interpolation=Image.BILINEAR),
              transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
          ])

          # Load data index
          split_dir = os.path.join(self.root_dir, self.split)
          with open(os.path.join(split_dir, 'index.json')) as f:
              self.index = json.load(f)

          n_samples = len(self.index)
          do_augment = augment and split == 'train'

          # Load (and optionally augment) the images
          self.images = []
          for k in range(n_samples):
              print('Loading {} image {}/{}...'.format(split, k, n_samples), end='\r')
              img_name = os.path.join(split_dir, 'images', self.index[str(k)]['ImgName'])
              # Close the file handle as soon as the pixels are copied out.
              with Image.open(img_name) as raw:
                  img = trans_img(np.array(raw.convert('RGB')))
              self.images.append(img)
              if do_augment:
                  self.images.extend(self._augmented_variants(img, blur=True))

          # Load ground truth for every split except 'test'
          if split != 'test':
              self.segs = []
              for k in range(n_samples):
                  entry = self.index[str(k)]
                  curr_label = entry['Label']
                  print('Loading {} segmentation {}/{}...'.format(split, k, n_samples), end='\r')
                  seg_name = os.path.join(split_dir, 'gts', self._mask_filename(entry['ImgName']))
                  with Image.open(seg_name) as raw:
                      # Invert so that larger values mean "more inside".
                      seg = 255. - np.array(raw)
                  od = (seg >= 127.).astype(np.float32)
                  oc = (seg >= 250.).astype(np.float32)
                  od = torch.from_numpy(od[None, :, :])
                  oc = torch.from_numpy(oc[None, :, :])
                  # Nearest-neighbour keeps the masks strictly binary.
                  od = transforms.functional.resize(od, self.output_size, interpolation=Image.NEAREST)
                  oc = transforms.functional.resize(oc, self.output_size, interpolation=Image.NEAREST)
                  seg = torch.cat([od, oc], dim=0)

                  variants = [seg]
                  if do_augment:
                      # Blurring a binary mask is meaningless, so the blur
                      # slot re-uses the untouched mask (blur=False).
                      variants.extend(self._augmented_variants(seg, blur=False))
                  self.segs.extend(variants)
                  # One label per stored mask keeps labels aligned with images.
                  self.labels.extend([curr_label] * len(variants))

          print('Succesfully loaded {} dataset.'.format(split) + ' '*50)

      @staticmethod
      def _mask_filename(img_name):
          """Return the ``.bmp`` mask file name matching image ``img_name``.

          Only the final extension is replaced, so names that contain
          additional dots are preserved.
          """
          return os.path.splitext(img_name)[0] + '.bmp'

      @staticmethod
      def _augmented_variants(tensor, blur=True):
          """Return the five augmented variants of ``tensor`` in a fixed order.

          Order: left-right mirror, top-bottom mirror, 5x5 Gaussian blur,
          rotation by +5 degrees, rotation by -5 degrees.  With
          ``blur=False`` the blur slot holds ``tensor`` itself, so images and
          masks always produce the same number of variants.
          """
          tf = transforms.functional
          return [
              tf.hflip(tensor),
              tf.vflip(tensor),
              tf.gaussian_blur(tensor, kernel_size=[5, 5]) if blur else tensor,
              tf.rotate(tensor, angle=5),
              tf.rotate(tensor, angle=-5),
          ]

      def __len__(self):
          """Return the number of stored images (augmented variants included)."""
          return len(self.images)

      def __getitem__(self, idx):
          """Return sample ``idx``.

          Returns:
              The image tensor alone for the 'test' split; otherwise a tuple
              ``(img, lab, seg)`` where ``lab`` is the stored label as a
              float32 tensor and ``seg`` is the two-channel mask.
          """
          img = self.images[idx]

          # Return only images for the 'test' set
          if self.split == 'test':
              return img

          # NOTE: per-sample metadata must not be read from
          # ``self.index[str(idx)]`` here: with augmentation enabled ``idx``
          # runs past the number of index entries.
          lab = torch.tensor(self.labels[idx], dtype=torch.float32)
          seg = self.segs[idx]
          return img, lab, seg
RAW Paste Data