Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class LandMarkRecognition(Dataset):
    """Dataset of ``.jpg`` images in ``root_dir`` labelled from a CSV file.

    Each item is ``(image, landmark)``: ``image`` is the file opened with PIL
    and converted to RGB (then passed through ``transform`` when one is set),
    and ``landmark`` is the row's ``landmark_id`` as a float. Rows are matched
    on the CSV ``id`` column, which must equal the image file name without its
    extension.

    Args:
        root_dir: Directory containing the images; only names ending in
            ``.jpg`` are used, in sorted order.
        csv_file: Anything accepted by ``pandas.read_csv``; must provide
            ``id`` and ``landmark_id`` columns.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, root_dir, csv_file, transform=None):
        self.landmarks_csv = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.image_names = [
            name for name in sorted(os.listdir(self.root_dir)) if name.endswith('.jpg')
        ]
        # Build the id -> label table once instead of scanning the whole CSV
        # on every __getitem__. keep='first' preserves the previous behaviour
        # of using the first matching row when an id is duplicated; ids are
        # keyed as strings because they are matched against file names.
        first_rows = self.landmarks_csv.drop_duplicates(subset='id', keep='first')
        self._label_by_id = dict(
            zip(first_rows['id'].astype(str), first_rows['landmark_id'])
        )

    def __len__(self):
        """Return the number of ``.jpg`` files found in ``root_dir``."""
        return len(self.image_names)

    def _landmark_for(self, image_name):
        """Return the ``landmark_id`` for *image_name* (a bare file name) as a float.

        Raises:
            KeyError: if the CSV has no row whose ``id`` equals the file stem.
        """
        # splitext on the bare file name: splitting the full path on '.'
        # broke whenever root_dir itself contained a dot (e.g. './data').
        image_id = os.path.splitext(image_name)[0]
        try:
            label = self._label_by_id[image_id]
        except KeyError:
            raise KeyError(
                f"no row with id {image_id!r} in landmarks CSV for {image_name!r}"
            ) from None
        return float(label)

    def __getitem__(self, idx):
        """Return ``(image, landmark)`` for the *idx*-th image in sorted order."""
        image_name = self.image_names[idx]
        img_path = os.path.join(self.root_dir, image_name)
        # Close the underlying file; convert() returns a new image copy that
        # remains usable after the source image is closed.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        landmarks = self._landmark_for(image_name)
        # Fall back to the untransformed image: previously `img` was unbound
        # (UnboundLocalError) when no transform was supplied.
        img = self.transform(image) if self.transform else image
        return img, landmarks
Add Comment
Please, Sign In to add comment