Advertisement
Guest User

Untitled

a guest
Oct 17th, 2019
133
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.39 KB | None | 0 0
  1. import os
  2. import torch
  3. import torch.utils.data
  4. import torchvision
  5. from PIL import Image
  6. from pycocotools.coco import COCO
  7.  
  8. class myOwnDataset(torch.utils.data.Dataset):
  9. def __init__(self, root, annotation, transforms=None):
  10. self.root = root
  11. self.transforms = transforms
  12. self.coco = COCO(annotation)
  13. self.ids = list(sorted(self.coco.imgs.keys()))
  14.  
  15. def __getitem__(self, index):
  16. # Own coco file
  17. coco = self.coco
  18. # Image ID
  19. img_id = self.ids[index]
  20. # List: get annotation id from coco
  21. ann_ids = coco.getAnnIds(imgIds=img_id)
  22. # Dictionary: target coco_annotation file for an image
  23. coco_annotation = coco.loadAnns(ann_ids)
  24. # path for input image
  25. path = coco.loadImgs(img_id)[0]['file_name']
  26. # open the input image
  27. img = Image.open(os.path.join(self.root, path))
  28.  
  29. # number of objects in the image
  30. num_objs = len(coco_annotation)
  31.  
  32. # Bounding boxes for objects
  33. # In coco format, bbox = [xmin, ymin, width, height]
  34. # In pytorch, the input should be [xmin, ymin, xmax, ymax]
  35. boxes = []
  36. for i in range(num_objs):
  37. xmin = coco_annotation[i]['bbox'][0]
  38. ymin = coco_annotation[i]['bbox'][1]
  39. xmax = xmin + coco_annotation[i]['bbox'][2]
  40. ymax = ymin + coco_annotation[i]['bbox'][3]
  41. boxes.append([xmin, ymin, xmax, ymax])
  42. boxes = torch.as_tensor(boxes, dtype=torch.float32)
  43. # Labels (In my case, I only one class: target class or background)
  44. labels = torch.ones((num_objs,), dtype=torch.int64)
  45. # Tensorise img_id
  46. img_id = torch.tensor([img_id])
  47. # Size of bbox (Rectangular)
  48. areas = []
  49. for i in range(num_objs):
  50. areas.append(coco_annotation[i]['area'])
  51. areas = torch.as_tensor(areas, dtype=torch.float32)
  52. # Iscrowd
  53. iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
  54.  
  55. # Annotation is in dictionary format
  56. my_annotation = {}
  57. my_annotation["boxes"] = boxes
  58. my_annotation["labels"] = labels
  59. my_annotation["image_id"] = img_id
  60. my_annotation["area"] = areas
  61. my_annotation["iscrowd"] = iscrowd
  62.  
  63. if self.transforms is not None:
  64. img = self.transforms(img)
  65.  
  66. return img, my_annotation
  67.  
  68. def __len__(self):
  69. return len(self.ids)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement