Guest User

Untitled

a guest
Nov 17th, 2019
87
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. import cv2
  3. import torch
  4. import click
  5. import PIL.Image
  6. import colorsys
  7. import matplotlib.pyplot as plt
  8. from torchvision.models.detection import maskrcnn_resnet50_fpn
  9. from torchvision.transforms import Normalize, ToTensor
  10.  
  11.  
  # Per-channel RGB mean and standard deviation, stored as the keyword
  # arguments for torchvision.transforms.Normalize (see load_image). These are
  # the values torchvision documents for its ImageNet-pretrained models.
  NORMALIZATION_STATS = dict(
      mean=[0.485, 0.456, 0.406],
      std=[0.229, 0.224, 0.225],
  )
  16.  
  17.  
  def load_image(path, normalize=True):
      """Load the image at *path* as a tensor plus its raw RGB pixels.

      The image is converted to 3-channel RGB first, so grayscale, palette
      and RGBA files work with the 3-element NORMALIZATION_STATS (a 1- or
      4-channel tensor cannot be normalized with them).

      Args:
          path: filename or file object accepted by ``PIL.Image.open``.
          normalize: if True (the default), standardize the tensor with
              NORMALIZATION_STATS. Pass False to get the plain [0, 1] tensor
              produced by ``ToTensor`` -- torchvision's detection models
              (e.g. ``maskrcnn_resnet50_fpn``) normalize internally and
              expect un-normalized [0, 1] input.

      Returns:
          ``(img, img_raw)``: ``img`` is a float tensor of shape (3, H, W);
          ``img_raw`` is the (H, W, 3) uint8 RGB pixel array.
      """
      # Use a context manager so the underlying file handle is released;
      # convert() returns a new, fully loaded image that outlives the block.
      with PIL.Image.open(path) as pil_img:
          rgb = pil_img.convert('RGB')
      img_raw = np.array(rgb, dtype=np.uint8)
      img = ToTensor()(rgb)
      if normalize:
          img = Normalize(**NORMALIZATION_STATS)(img)
      return img, img_raw
  24.  
  25.  
  @click.command()
  @click.option('-d', '--device', default='cuda', type=str)
  @click.argument('path')
  def main(path, device):
      """Detect objects in the image at PATH and plot their bounding boxes.

      Runs a pretrained torchvision Mask R-CNN on --device and draws each
      detected box in its own colour. Masks and labels are not drawn yet.
      """
      device = torch.device(device)

      # torchvision detection models normalize their input internally and
      # expect values in [0, 1]. load_image's tensor is already standardized,
      # so feeding it would normalize twice; rebuild a [0, 1] tensor from the
      # raw uint8 pixels instead.
      _, img_raw = load_image(path)
      img = ToTensor()(img_raw).to(device)

      model = maskrcnn_resnet50_fpn(pretrained=True, rpn_nms_thresh=0.7)
      model = model.to(device).eval()
      with torch.no_grad():
          # The model takes a list of (C, H, W) tensors and returns one
          # result dict per image.
          out = model([img])[0]

      # Boxes are (x1, y1, x2, y2) corners. cv2 needs plain Python ints and
      # the tensor may live on the GPU, so round, move to host and convert.
      boxes = out['boxes'].round().int().cpu().tolist()
      for idx, (x1, y1, x2, y2) in enumerate(boxes):
          # Spread the box colours evenly around the HSV hue circle.
          r, g, b = colorsys.hsv_to_rgb(idx / len(boxes), 1., 1.)
          color_rgb = (int(r * 255), int(g * 255), int(b * 255))
          # Draws in place on img_raw (RGB order, matching matplotlib).
          cv2.rectangle(img_raw, (x1, y1), (x2, y2), color_rgb)

      # TODO: overlay out['masks'] as well (e.g. blend with cv2.addWeighted).
      # The image is shown even when there are no detections.
      plt.imshow(img_raw)
      plt.show()
  64.  
  65.  
  if __name__ == '__main__':
      # main is a click command: calling it parses the command line and
      # dispatches to the decorated function.
      main()
RAW Paste Data