import torch
import os
import os.path
import torchvision.transforms.functional as F
import numpy as np
from data.vision_dataset import VisionDataset
from PIL import Image
from general_config.anchor_config import default_boxes
from utils.preprocessing import match, prepare_gt, get_bboxes

from albumentations import (
    Resize,
    RandomResizedCrop,
    HorizontalFlip,
    Rotate,
    Blur,
    CLAHE,
    ChannelDropout,
    CoarseDropout,
    GaussNoise,
    RandomBrightnessContrast,
    RandomGamma,
    ToGray,
    Compose,
    BboxParams
)


class CocoDetection(VisionDataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to the json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes an input sample and its target
            and returns a transformed version.

    We use the COCO API, on top of which we build our custom data processing.
    A usage sketch is appended at the end of this file.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None, augmentation=True, params=None):
        super().__init__(root, transforms, transform, target_transform)
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.augmentation = augmentation
        self.params = params

        self.anchors_ltrb = default_boxes(order='ltrb')
        self.anchors_xywh = default_boxes(order='xywh')

        self.augmentations = self.get_aug([RandomResizedCrop(height=self.params.input_height,
                                                             width=self.params.input_width, scale=(0.4, 1.0)),
                                           HorizontalFlip(), Rotate(limit=10),
                                           Blur(p=0.2), CLAHE(p=0.25), ChannelDropout(p=0.1),
                                           CoarseDropout(max_holes=8, max_height=20, max_width=20),
                                           GaussNoise(p=0.15), RandomBrightnessContrast(),
                                           RandomGamma(), ToGray(p=0.25),
                                           ], min_visibility=0.15)

        self.just_resize = self.get_aug(
            [Resize(height=self.params.input_height, width=self.params.input_width)])

    def __getitem__(self, batched_indices):
        """
        Returns a B x C x H x W image tensor, a (B x #anchors x 4, B x #anchors x 1)
        label pair and per-image (img_id, (orig_width, orig_height)) info.
        """
        imgs, targets_bboxes, targets_classes, image_info = [], [], [], []
        for index in batched_indices:
            coco = self.coco
            img_id = self.ids[index]
            ann_ids = coco.getAnnIds(imgIds=img_id)
            target = coco.loadAnns(ann_ids)
            path = coco.loadImgs(img_id)[0]['file_name']
            img = Image.open(os.path.join(self.root, path)).convert('RGB')
            orig_width, orig_height = img.size

            # get useful annotations
            bboxes, category_ids = get_bboxes(target)
            bboxes, category_ids = self.check_bbox_validity(
                bboxes, category_ids, orig_width, orig_height)
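            # images whose boxes were all filtered out are skipped entirely,
            # so a batch can end up with fewer samples than indices requested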
            if len(bboxes) == 0:
                continue

            album_annotation = {'image': np.array(img),
                                'bboxes': bboxes, 'category_id': category_ids}
            if self.augmentation:
                transform_result = self.augmentations(**album_annotation)
            else:
                transform_result = self.just_resize(**album_annotation)
            image = transform_result['image']
            bboxes = transform_result['bboxes']
            category_ids = transform_result['category_id']

            # bring bboxes to correct format and check they are valid
            target = prepare_gt(image, bboxes, category_ids)

            # get image in right format - normalized tensor
            image = F.to_tensor(image)
            image = F.normalize(image, mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

            # match ground truth to anchors: #anchors x 4 and #anchors x 1
            gt_bbox, gt_class = match(self.anchors_ltrb, self.anchors_xywh,
                                      target[0], target[1], self.params)

            imgs.append(image)
            targets_bboxes.append(gt_bbox)
            targets_classes.append(gt_class)
            image_info.append((img_id, (orig_width, orig_height)))

        # B x C x H x W
        batch_images = torch.stack(imgs)

        # B x #anchors x 4 and 1 respectively
        batch_bboxes = torch.stack(targets_bboxes)
        batch_class_ids = torch.stack(targets_classes)

        label = (batch_bboxes, batch_class_ids)

        return batch_images, label, image_info

    def __len__(self):
        return len(self.ids)

    def get_aug(self, aug, min_area=0., min_visibility=0.3):
        """
        Args:
            aug - list of albumentations augmentations
            min_area - minimum bbox area (in pixels) required to keep a bbox
            min_visibility - minimum fraction of the original bbox area that must remain
                visible after the transform for the bbox to be kept
        """
        return Compose(aug, bbox_params=BboxParams(format='coco', min_area=min_area,
                                                   min_visibility=min_visibility, label_fields=['category_id']))

    # def check_bbox_validity(self, target):
    #     if target[0].nelement() == 0:
    #         return
    #
    #     eps = 0.00001
    #     gt_bbox = target[0]
    #
    #     # x and y must be positive
    #     col_1_ok = gt_bbox[:, 0] > 0
    #     col_2_ok = gt_bbox[:, 1] > 0
    #
    #     # width and height must be strictly greater than zero
    #     col_3_ok = gt_bbox[:, 2] > eps
    #     col_4_ok = gt_bbox[:, 3] > eps
    #
    #     # rows to keep
    #     ok = col_1_ok * col_2_ok * col_3_ok * col_4_ok
    #     target[0] = target[0][ok]
    #     target[1] = target[1][ok]

    def check_bbox_validity(self, bboxes, category_ids, width, height):
        """Drop degenerate boxes and boxes touching or exceeding the image borders."""
        eps = 0.000001
        valid_bboxes, valid_ids = [], []
        for bbox, cat_id in zip(bboxes, category_ids):
            # skip boxes with (near-)zero area
            if bbox[2] * bbox[3] <= eps:
                continue
            # skip boxes whose corners fall on or outside the image bounds
            if bbox[0] <= eps or bbox[1] <= eps or (bbox[0] + bbox[2]) >= (width - eps) or (bbox[1] + bbox[3]) >= (height - eps):
                continue
            valid_bboxes.append(bbox)
            valid_ids.append(cat_id)

        return valid_bboxes, valid_ids
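

# ---------------------------------------------------------------------------
# Minimal usage sketch, under a few assumptions. Since __getitem__ takes a
# list of indices and already returns a collated batch, it is presumably
# driven by a BatchSampler with automatic batching disabled. The paths and
# the Params container below are placeholders; the real `params` object must
# expose `input_height` and `input_width` (read in __init__) plus whatever
# `match` expects.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, BatchSampler, RandomSampler

    class Params:
        # hypothetical parameter container matching the attributes used above
        input_height = 320
        input_width = 320

    dataset = CocoDetection(root='data/coco/train2017',  # placeholder paths
                            annFile='data/coco/annotations/instances_train2017.json',
                            augmentation=True, params=Params())

    # direct call: one "index" is a whole batch of dataset indices
    images, (gt_bboxes, gt_classes), info = dataset[list(range(8))]

    # with a DataLoader: the sampler yields index lists, and batch_size=None
    # disables automatic batching so each list is passed straight to __getitem__
    sampler = BatchSampler(RandomSampler(dataset), batch_size=8, drop_last=True)
    loader = DataLoader(dataset, sampler=sampler, batch_size=None)
    images, (gt_bboxes, gt_classes), info = next(iter(loader))
    print(images.shape, gt_bboxes.shape, gt_classes.shape)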