Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import os
- import os.path
- import torchvision.transforms.functional as F
- import numpy as np
- from data.vision_dataset import VisionDataset
- from PIL import Image
- from general_config.anchor_config import default_boxes
- from utils.preprocessing import match, prepare_gt, get_bboxes
- from albumentations import (
- Resize,
- RandomResizedCrop,
- HorizontalFlip,
- Rotate,
- Blur,
- CLAHE,
- ChannelDropout,
- CoarseDropout,
- GaussNoise,
- RandomBrightnessContrast,
- RandomGamma,
- ToGray,
- Compose,
- BboxParams
- )
class CocoDetection(VisionDataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Built on top of the COCO API, with custom data processing added:
    bounding-box filtering, albumentations transforms and matching of the
    ground truth against the default anchor boxes.

    Note that ``__getitem__`` receives an iterable of indices and returns a
    whole batch, so the dataset has to be indexed with batches of indices
    rather than with single integers.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
        augmentation (bool, optional): If True, the random augmentation pipeline is
            applied to every image; otherwise images are only resized.
        params: Configuration object. ``params.input_height`` and
            ``params.input_width`` are read when the pipelines are built and the
            object is forwarded to ``match``, so despite the ``None`` default it
            must be provided.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None,
                 transforms=None, augmentation=True, params=None):
        super().__init__(root, transforms, transform, target_transform)
        # Function-scope import: pycocotools is only needed once a dataset
        # is actually instantiated.
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        self.ids = sorted(self.coco.imgs.keys())

        self.augmentation = augmentation
        self.params = params

        # The same default boxes in both layouts expected by ``match``.
        self.anchors_ltrb = default_boxes(order='ltrb')
        self.anchors_xywh = default_boxes(order='xywh')

        # Random pipeline used when ``augmentation`` is True.
        self.augmentations = self.get_aug(
            [
                RandomResizedCrop(height=self.params.input_height,
                                  width=self.params.input_width,
                                  scale=(0.4, 1.0)),
                HorizontalFlip(),
                Rotate(limit=10),
                Blur(p=0.2),
                CLAHE(p=0.25),
                ChannelDropout(p=0.1),
                CoarseDropout(max_holes=8, max_height=20, max_width=20),
                GaussNoise(p=0.15),
                RandomBrightnessContrast(),
                RandomGamma(),
                ToGray(p=0.25),
            ],
            min_visibility=0.15)

        # Deterministic pipeline used when ``augmentation`` is False.
        self.just_resize = self.get_aug(
            [Resize(height=self.params.input_height,
                    width=self.params.input_width)])

    def __getitem__(self, batched_indices):
        """Load, transform and anchor-match a whole batch of images.

        Images that have no valid bounding box (see ``check_bbox_validity``)
        are skipped, so the returned batch may be smaller than
        ``batched_indices``.

        Args:
            batched_indices: iterable of positions into ``self.ids``.

        Returns:
            A tuple ``(batch_images, label, image_info)`` where
            ``batch_images`` is a B x C x H x W tensor, ``label`` is the pair
            ``(batch_bboxes, batch_class_ids)`` of stacked per-anchor targets
            produced by ``match``, and ``image_info`` is a list of
            ``(img_id, (orig_width, orig_height))`` tuples.

        Raises:
            RuntimeError: if every image of the batch was skipped.
        """
        coco = self.coco
        imgs, targets_bboxes, targets_classes, image_info = [], [], [], []

        for index in batched_indices:
            img_id = self.ids[index]
            target = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
            path = coco.loadImgs(img_id)[0]['file_name']

            # Close the underlying file as soon as the RGB copy exists.
            with Image.open(os.path.join(self.root, path)) as img_file:
                img = img_file.convert('RGB')
            orig_width, orig_height = img.size

            # get useful annotations
            bboxes, category_ids = get_bboxes(target)
            bboxes, category_ids = self.check_bbox_validity(
                bboxes, category_ids, orig_width, orig_height)
            if not bboxes:
                continue

            album_annotation = {'image': np.array(img),
                                'bboxes': bboxes,
                                'category_id': category_ids}
            pipeline = self.augmentations if self.augmentation else self.just_resize
            transform_result = pipeline(**album_annotation)

            # Look the results up by key instead of relying on the order of
            # the returned dict's values.
            image = transform_result['image']
            bboxes = transform_result['bboxes']
            category_ids = transform_result['category_id']

            # bring bboxes to correct format and check they are valid
            target = prepare_gt(image, bboxes, category_ids)

            # get image in right format - normalized tensor
            image = F.to_tensor(image)
            image = F.normalize(image, mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

            # #anchors x 4 and #anchors x 1
            gt_bbox, gt_class = match(self.anchors_ltrb, self.anchors_xywh,
                                      target[0], target[1], self.params)

            imgs.append(image)
            targets_bboxes.append(gt_bbox)
            targets_classes.append(gt_class)
            image_info.append((img_id, (orig_width, orig_height)))

        if not imgs:
            # torch.stack would fail on an empty list anyway; fail with a
            # message that names the actual cause.
            raise RuntimeError(
                'No image in the requested batch has a valid bounding box')

        # B x C x H x W
        batch_images = torch.stack(imgs)

        # B x #anchors x 4 and 1 respectively
        batch_bboxes = torch.stack(targets_bboxes)
        batch_class_ids = torch.stack(targets_classes)

        label = (batch_bboxes, batch_class_ids)
        return batch_images, label, image_info

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.ids)

    def get_aug(self, aug, min_area=0., min_visibility=0.3):
        """Wrap a list of albumentations transforms into a bbox-aware pipeline.

        Args:
            aug - set of albumentation augmentations
            min_area - minimum area to keep bbox
            min_visibility - minimum area percentage (to keep bbox) of original bbox after transform

        Returns:
            A ``Compose`` object expecting COCO-format boxes and a
            ``category_id`` label field.
        """
        bbox_params = BboxParams(format='coco',
                                 min_area=min_area,
                                 min_visibility=min_visibility,
                                 label_fields=['category_id'])
        return Compose(aug, bbox_params=bbox_params)

    def check_bbox_validity(self, bboxes, category_ids, width, height):
        """Drop degenerate boxes and boxes not strictly inside the image.

        Each box is indexed as ``(x, y, w, h)``. A box is kept only when its
        area is greater than a small epsilon and it lies strictly inside the
        ``width`` x ``height`` image (boxes touching the border are dropped).

        Args:
            bboxes: sequence of boxes.
            category_ids: sequence of labels, parallel to ``bboxes``.
            width: image width.
            height: image height.

        Returns:
            ``(valid_bboxes, valid_ids)``: two lists holding the boxes that
            passed the checks and their labels, in the original order.
        """
        eps = 0.000001
        valid_bboxes, valid_ids = [], []

        for bbox, category_id in zip(bboxes, category_ids):
            # width * height must be strictly greater than zero
            if bbox[2] * bbox[3] <= eps:
                continue
            # the box must lie strictly inside the image
            if (bbox[0] <= eps or bbox[1] <= eps
                    or (bbox[0] + bbox[2]) >= (width - eps)
                    or (bbox[1] + bbox[3]) >= (height - eps)):
                continue
            valid_bboxes.append(bbox)
            valid_ids.append(category_id)

        return valid_bboxes, valid_ids
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement