Advertisement
Guest User

Untitled

a guest
Apr 2nd, 2020
162
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.83 KB | None | 0 0
  1. class BasicDataset(Dataset):
  2.     def __init__(self, imgs_dir, masks_dir, scale=1):
  3.         self.imgs_dir = imgs_dir
  4.         self.masks_dir = masks_dir
  5.         self.scale = scale
  6.         assert 0 < scale <= 1, 'Scale must be between 0 and 1'
  7.  
  8.         self.ids = [file.split('-')[0] for file in listdir(imgs_dir)
  9.                     if not file.startswith('.')]
  10.         logging.info(f'Creating dataset with {len(self.ids)} examples')
  11.  
  12.     def __len__(self):
  13.         return len(self.ids)
  14.  
  15.     @classmethod
  16.     def preprocess(cls, pil_img, scale):
  17.         w, h = pil_img.size
  18.         newW, newH = int(scale * w), int(scale * h)
  19.         assert newW > 0 and newH > 0, 'Scale is too small'
  20.         pil_img = pil_img.resize((newW, newH))
  21.  
  22.         img_nd = np.array(pil_img)
  23.  
  24.         if len(img_nd.shape) == 2:
  25.             img_nd = np.expand_dims(img_nd, axis=2)
  26.  
  27.         # HWC to CHW
  28.         img_trans = img_nd.transpose((2, 0, 1))
  29.         if img_trans.max() > 1:
  30.             img_trans = img_trans / 255.
  31.  
  32.         return img_trans
  33.  
  34.     def __getitem__(self, i):
  35.         idx = self.ids[i]
  36.         mask_file = glob(self.masks_dir + idx + '*')
  37.         img_file = glob(self.imgs_dir + idx + '*')
  38.  
  39.         assert len(mask_file) == 1, \
  40.             f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
  41.         assert len(img_file) == 1, \
  42.             f'Either no image or multiple images found for the ID {idx}: {img_file}'
  43.         mask = Image.open(mask_file[0])
  44.         img = Image.open(img_file[0])
  45.  
  46.         assert img.size == mask.size, \
  47.             f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
  48.  
  49.         img = self.preprocess(img, self.scale)
  50.         mask = self.preprocess(mask, self.scale)
  51.  
  52.         return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement