Guest User

Untitled

a guest
Feb 25th, 2018
68
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.00 KB | None | 0 0
  class LandMarkRecognition(Dataset):
      """Dataset of ``.jpg`` images in ``root_dir`` labelled via a landmarks CSV.

      The CSV is read with ``pd.read_csv`` and must contain an ``id`` column
      (the image filename without its extension) and a ``landmark_id`` column.
      Each item is ``(image, landmark)``: the RGB PIL image (passed through
      ``transform`` when one is given) and that image's ``landmark_id`` as a
      float.
      """

      def __init__(self, root_dir, csv_file, transform=None):
          """Index the images and build the id -> label lookup.

          Args:
              root_dir: Directory holding the images; only entries whose names
                  end in ``.jpg`` are used.
              csv_file: Path or buffer accepted by ``pd.read_csv``.
              transform: Optional callable applied to the RGB image in
                  ``__getitem__``.

          Raises:
              KeyError: If the CSV lacks an ``id`` or ``landmark_id`` column.
          """
          self.landmarks_csv = pd.read_csv(csv_file)
          self.root_dir = root_dir
          self.transform = transform
          # Sorted so the index -> file mapping is deterministic across runs.
          self.image_names = [
              name for name in sorted(os.listdir(self.root_dir))
              if name.endswith('.jpg')
          ]
          # Build the lookup once instead of scanning the whole DataFrame on
          # every __getitem__ call. For repeated ids keep the first row, which
          # is what the previous ``index[mask][0]`` selection picked.
          first_rows = self.landmarks_csv.drop_duplicates(subset='id')
          self._landmark_by_id = dict(
              zip(first_rows['id'], first_rows['landmark_id'])
          )

      def __len__(self):
          """Return the number of ``.jpg`` images found in ``root_dir``."""
          return len(self.image_names)

      def __getitem__(self, idx):
          """Return ``(image, landmark)`` for the ``idx``-th image (sorted order).

          Raises:
              IndexError: If ``idx`` is out of range.
              KeyError: If the image's id has no row in the CSV.
          """
          file_name = self.image_names[idx]
          # Resolve the label before touching the file so a missing CSV row
          # fails fast without doing image I/O.
          landmark = self._landmark(file_name)
          img_path = os.path.join(self.root_dir, file_name)
          # convert() returns a new, fully loaded image, so the source file
          # handle can be released immediately instead of leaking.
          with Image.open(img_path) as raw:
              image = raw.convert('RGB')
          # The transform is optional; without this fallback the image variable
          # was left unbound when transform was None.
          if self.transform is not None:
              image = self.transform(image)
          return image, landmark

      def _landmark(self, file_name):
          """Return the ``landmark_id`` (as float) for an image file name."""
          # Take the id from the bare file name, never from the joined path:
          # splitting the path on '.' broke whenever root_dir contained a dot.
          image_id = os.path.splitext(file_name)[0]
          try:
              return float(self._landmark_by_id[image_id])
          except KeyError:
              raise KeyError(
                  f"no row with id {image_id!r} in landmarks CSV"
              ) from None
Add Comment
Please, Sign In to add comment