SHARE
TWEET

Untitled

a guest Nov 17th, 2019 84 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. import cv2
  3. import torch
  4. import click
  5. import PIL.Image
  6. import colorsys
  7. import matplotlib.pyplot as plt
  8. from torchvision.models.detection import maskrcnn_resnet50_fpn
  9. from torchvision.transforms import Normalize, ToTensor
  10.  
  11.  
  12. NORMALIZATION_STATS = {
  13.     'mean': [0.485, 0.456, 0.406],
  14.     'std': [0.229, 0.224, 0.225]
  15. }
  16.  
  17.  
  18. def load_image(path):
  19.     img = PIL.Image.open(path)
  20.     img_raw = np.array(img, dtype=np.uint8)
  21.     img = ToTensor()(img)
  22.     img = Normalize(**NORMALIZATION_STATS)(img)
  23.     return img, img_raw
  24.  
  25.  
  26. @click.command()
  27. @click.option('-d', '--device', default='cuda', type=str)
  28. @click.argument('path')
  29. def main(path, device):
  30.     device = torch.device(device)
  31.     img, img_raw = load_image(path)
  32.     img = img.to(device)
  33.     model = maskrcnn_resnet50_fpn(pretrained=True, rpn_nms_thresh=0.7)
  34.     model = model.to(device).eval()
  35.     with torch.no_grad():
  36.         out = model(img.unsqueeze(0))[0]
  37.     boxes, labels, masks = out['boxes'], out['labels'], out['masks']
  38.  
  39.     # def _func(x):
  40.     #     return x
  41.     #
  42.     # trasformed_boxes = [_func(box) for box in boxes if True]
  43.     # trasformed_boxes = map(lambda x: x ** 2, boxes)
  44.     # l = []
  45.     # for box in boxes:
  46.     #     l.append(_func(box))
  47.     # boxes = l
  48.     # out = dict(())
  49.     img = img.cpu()
  50.     out_masks = np.zeros_like(img, dtype=np.uint8)
  51.  
  52.     for idx, (box, lab, mask) in enumerate(zip(boxes, labels, masks)):
  53.         start_x, start_y, width, height = box
  54.         p1 = (start_x, start_y)
  55.         p2 = (width, height)
  56.         color_hsv = {'h': idx / len(masks), 's': 1., 'v': 1.}
  57.         color_rgb = tuple(map(lambda x: int(x*255), colorsys.hsv_to_rgb(**color_hsv)))
  58.         img_after = cv2.rectangle(img_raw, p1, p2, color_rgb)
  59.  
  60.     plt.imshow(img_after)
  61.     plt.show()
  62.  
  63.         # cv2.addWeighted()
  64.  
  65.  
  66. if __name__ ==  '__main__':
  67.     main()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top