Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class SegDataset(Dataset):
    """Segmentation dataset of pre-computed random crops/flips of image pairs.

    Every (raw, seg) image pair listed in the CSV is loaded once and expanded
    into ``augments`` randomly cropped and flipped copies at construction
    time. The resulting samples are shuffled and served from memory by
    ``__getitem__``.

    NOTE(review): all ``len(csv rows) * augments`` samples are materialised up
    front and moved to the module-level ``device``. With the default of 200
    augments per pair this can use a lot of (possibly GPU) memory, and the
    augmentations never change between epochs.
    """

    def __init__(self, csv_loc, data_dir, augments=200, crop_size=(128, 128)):
        """Load the image index and pre-compute all augmented samples.

        Args:
            csv_loc: Location passed to ``read_csv``. Assumed (TODO confirm)
                to yield rows indexable as ``rows[idx][0]`` (raw image file
                name) and ``rows[idx][1]`` (segmentation file name).
            data_dir: Directory that the CSV file names are joined onto.
            augments: Number of random crop/flip samples generated per pair.
            crop_size: ``(height, width)`` of each random crop.
        """
        self.data_dir = data_dir
        # Must be stored before prepare_images(), which reads self.augments.
        self.augments = augments
        self.crop_size = crop_size
        self.images_data = read_csv(csv_loc)
        self.images = self.prepare_images()

    def transform(self, raw, seg):
        """Apply one shared random crop and random flips to a raw/seg pair.

        Both images get the same crop window and the same horizontal and
        vertical flips (each taken when ``random.random() > 0.5``), so the
        segmentation stays aligned with the raw image.

        Returns:
            dict with ``'raw'`` (float tensor) and ``'seg'`` (int64 tensor),
            both moved to the module-level ``device``. Values are rescaled
            back to the 0..255 pixel range, which assumes 8-bit PIL images
            (TODO confirm image modes).

        Raises:
            Whatever ``transforms.RandomCrop.get_params`` raises when an image
            is smaller than ``self.crop_size`` (ValueError in recent
            torchvision versions).
        """
        top, left, height, width = transforms.RandomCrop.get_params(
            raw, output_size=self.crop_size)
        raw = trans_f.crop(raw, top, left, height, width)
        seg = trans_f.crop(seg, top, left, height, width)
        if random.random() > 0.5:
            raw = trans_f.hflip(raw)
            seg = trans_f.hflip(seg)
        if random.random() > 0.5:
            raw = trans_f.vflip(raw)
            seg = trans_f.vflip(seg)
        # to_tensor scales 8-bit images into [0, 1]; undo that scaling.
        raw = trans_f.to_tensor(raw).mul(255).float().to(device)
        seg = self._scaled_to_labels(trans_f.to_tensor(seg)).to(device)
        return {'raw': raw, 'seg': seg}

    @staticmethod
    def _scaled_to_labels(scaled):
        """Convert a tensor scaled into [0, 1] by ``to_tensor`` to int64 labels.

        ``x / 255 * 255`` is not guaranteed to be exactly ``x`` in floating
        point. A bare ``.long()`` truncates and could turn label ``k`` into
        ``k - 1``, so the value is rounded before the integer cast.
        """
        return scaled.mul(255).round().long()

    def _read_image(self, idx, raw_image=True):
        """Load the raw (column 0) or segmentation (column 1) image of row ``idx``.

        The image is fully loaded into memory and its file is closed before
        returning, so no file handle outlives this call.
        """
        img_name = os.path.join(self.data_dir,
                                self.images_data[idx][int(not raw_image)])
        with Image.open(img_name) as img:
            # copy() forces the pixel data to load while the file is open.
            return img.copy()

    def prepare_images(self):
        """Return a shuffled list of ``self.augments`` samples per CSV row.

        Each sample is a ``transform`` output dict. Rows are addressed by index
        to keep the original ``images_data[idx][col]`` access pattern.
        NOTE(review): a pandas DataFrame would index columns there, not rows.
        """
        samples = []
        for idx in range(len(self.images_data)):
            raw = self._read_image(idx)
            seg = self._read_image(idx, raw_image=False)
            samples.extend(self.transform(raw, seg)
                           for _ in range(self.augments))
        random.shuffle(samples)
        return samples

    def __len__(self):
        """Number of pre-computed augmented samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the pre-computed sample dict at position ``idx``."""
        return self.images[idx]
Add Comment
Please, Sign In to add comment