Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class RefugeDataset(Dataset):
    """In-memory REFUGE dataset: fundus images, labels and two-channel masks.

    Everything is loaded eagerly in ``__init__`` from ``<root_dir>/<split>/``:

    * ``index.json`` -- a dict read with keys ``"0"``, ``"1"``, ... whose
      entries provide ``'ImgName'`` and, for non-test splits, ``'Label'``
      (assumes the keys are contiguous stringified integers -- TODO confirm
      against the data).
    * ``images/<ImgName>`` -- converted to RGB, resized to ``output_size``
      (bilinear) and normalised with fixed per-channel mean/std (presumably
      ImageNet statistics).
    * ``gts/<ImgName stem>.bmp`` (non-test splits only) -- decoded into a
      two-channel float tensor resized to ``output_size`` (nearest):
      channel 0 is ``od``, channel 1 is ``oc`` (optic disc / optic cup by
      naming); ``oc`` is always a subset of ``od``. Assumes a single-channel
      (grey-level) bmp -- TODO confirm.

    With ``augment=True`` on the ``'train'`` split, every sample is followed by
    five deterministic variants (see ``_augmentations``), so ``images``,
    ``segs`` and ``labels`` hold six consecutive entries per index key.

    ``__getitem__`` returns ``img`` for the ``'test'`` split and
    ``(img, label, seg)`` otherwise.
    """

    def __init__(self, root_dir, split='train', output_size=(256, 256), augment=False):
        """Load the whole ``split`` into memory.

        Args:
            root_dir: dataset root containing one sub-directory per split.
            split: split sub-directory name; ``'test'`` skips labels/masks.
            output_size: target size passed to torchvision's resize.
            augment: add the fixed augmentations (``'train'`` split only).
        """
        self.output_size = output_size
        self.root_dir = root_dir
        self.split = split
        self.labels = []

        split_dir = os.path.join(self.root_dir, self.split)
        with open(os.path.join(split_dir, 'index.json')) as f:
            self.index = json.load(f)
        n_items = len(self.index)

        # Augmentation only ever applies to the training split.
        augmentations = self._augmentations() if augment and split == 'train' else []
        copies = 1 + len(augmentations)

        trans_img = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.output_size,
                              interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # Images: original first, then its augmented variants in fixed order.
        self.images = []
        for k in range(n_items):
            print('Loading {} image {}/{}...'.format(split, k + 1, n_items), end='\r')
            entry = self.index[str(k)]
            img = self._load_image(
                os.path.join(split_dir, 'images', entry['ImgName']), trans_img)
            self.images.append(img)
            self.images.extend(img_fn(img) for img_fn, _ in augmentations)

        # Ground truth for 'train' and 'val'; same ordering as the images so
        # that position i in images/segs/labels always refers to one sample.
        if split != 'test':
            self.segs = []
            for k in range(n_items):
                print('Loading {} segmentation {}/{}...'.format(split, k + 1, n_items), end='\r')
                entry = self.index[str(k)]
                seg = self._load_mask(
                    os.path.join(split_dir, 'gts', self._gt_filename(entry['ImgName'])))
                self.segs.append(seg)
                self.segs.extend(mask_fn(seg) for _, mask_fn in augmentations)
                # Augmentations never change the class label.
                self.labels.extend([entry['Label']] * copies)
        print('Succesfully loaded {} dataset.'.format(split) + ' '*50)

    @staticmethod
    def _augmentations():
        """Return the ``(image_fn, mask_fn)`` pairs used for training augmentation.

        Both functions of a pair are applied to the same sample, which keeps
        each augmented image aligned with its augmented mask. The list order
        fixes where each variant lands in ``images``/``segs``/``labels``.
        """
        F = transforms.functional
        return [
            # Mirror left-right.
            (F.hflip, F.hflip),
            # Mirror top-bottom.
            (F.vflip, F.vflip),
            # Deterministic 5x5 Gaussian blur of the image only; blurring a
            # binary mask makes no sense, so the mask is reused unchanged.
            (lambda t: F.gaussian_blur(t, kernel_size=[5, 5]), lambda t: t),
            # Small rotations with torchvision's default interpolation
            # (nearest), which keeps the masks binary.
            (lambda t: F.rotate(t, angle=5), lambda t: F.rotate(t, angle=5)),
            (lambda t: F.rotate(t, angle=-5), lambda t: F.rotate(t, angle=-5)),
        ]

    @staticmethod
    def _load_image(path, trans_img):
        """Read ``path`` as RGB and apply ``trans_img`` to the numpy array."""
        # The context manager closes the file handle deterministically.
        with Image.open(path) as pil_img:
            rgb = np.array(pil_img.convert('RGB'))
        return trans_img(rgb)

    def _load_mask(self, path):
        """Read a ground-truth bmp and return the resized ``[od, oc]`` tensor."""
        with Image.open(path) as pil_seg:
            raw = np.array(pil_seg)
        od, oc = self._decode_masks(raw)
        seg = torch.from_numpy(np.stack([od, oc]))
        # Nearest-neighbour keeps the masks strictly 0/1 after resizing.
        return transforms.functional.resize(
            seg, self.output_size, interpolation=transforms.InterpolationMode.NEAREST)

    @staticmethod
    def _decode_masks(raw):
        """Split a raw grey-level array into ``(od, oc)`` float32 0/1 masks.

        Grey levels are inverted (``255 - value``); inverted values >= 127 go
        to ``od`` and values >= 250 go to ``oc``, so ``oc`` is a subset of
        ``od``.
        """
        inverted = 255. - np.asarray(raw)
        od = (inverted >= 127.).astype(np.float32)
        oc = (inverted >= 250.).astype(np.float32)
        return od, oc

    @staticmethod
    def _gt_filename(img_name):
        """Map an image file name to its ground-truth ``.bmp`` file name.

        Only the final extension is replaced, so dots inside the stem are kept.
        """
        return os.path.splitext(img_name)[0] + '.bmp'

    def __len__(self):
        """Number of stored samples (including augmented variants)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``img`` for the test split, else ``(img, label, seg)``.

        ``label`` is a float32 scalar tensor. Note: with augmentation, ``idx``
        does not map 1:1 onto ``index.json`` keys, so per-key metadata (e.g.
        fovea coordinates) must not be looked up with ``self.index[str(idx)]``.
        """
        img = self.images[idx]
        if self.split == 'test':
            return img
        lab = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, lab, self.segs[idx]
Advertisement
Add Comment
Please, Sign In to add comment