Guest User

Untitled

a guest
Jun 23rd, 2018
124
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.46 KB | None | 0 0
  1. class SegDataset(Dataset):
  2. def __init__(self, csv_loc, data_dir, augments=200):
  3. self.data_dir = data_dir
  4. self.images_data = read_csv(csv_loc)
  5. self.images = self.prepare_images()
  6.  
  7. def transform(self, raw, seg):
  8. i, j, h, w = transforms.RandomCrop.get_params(
  9. raw, output_size=(128, 128))
  10. raw = trans_f.crop(raw, i, j, h, w)
  11. seg = trans_f.crop(seg, i, j, h, w)
  12.  
  13. if random.random() > 0.5:
  14. raw = trans_f.hflip(raw)
  15. seg = trans_f.hflip(seg)
  16.  
  17. if random.random() > 0.5:
  18. raw = trans_f.vflip(raw)
  19. seg = trans_f.vflip(seg)
  20.  
  21. raw = trans_f.to_tensor(raw).mul(255).float().to(device)
  22. seg = trans_f.to_tensor(seg).mul(255).long().to(device)
  23.  
  24. return {'raw': raw, 'seg': seg}
  25.  
  26. def prepare_images(self):
  27. images = []
  28.  
  29. def read_image_by_id(idx, raw_image=True):
  30. img_name = os.path.join(self.data_dir,
  31. self.images_data[idx][int(not raw_image)])
  32. return Image.open(img_name)
  33.  
  34. for idx in range(len(self.images_data)):
  35. raw = read_image_by_id(idx)
  36. seg = read_image_by_id(idx, False)
  37. for _ in range(self.augments):
  38. images.append(self.transform(raw, seg))
  39. random.shuffle(images)
  40. return images
  41.  
  42. def __len__(self):
  43. return len(self.images)
  44.  
  45. def __getitem__(self, idx):
  46. return self.images[idx]
Add Comment
Please, Sign In to add comment