Guest User

Untitled

a guest
Nov 17th, 2019
183
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.86 KB | None | 0 0
  1. import numpy as np
  2. import cv2
  3. import torch
  4. import click
  5. import PIL.Image
  6. import colorsys
  7. import matplotlib.pyplot as plt
  8. from torchvision.models.detection import maskrcnn_resnet50_fpn
  9. from torchvision.transforms import Normalize, ToTensor
  10.  
  11.  
  12. NORMALIZATION_STATS = {
  13. 'mean': [0.485, 0.456, 0.406],
  14. 'std': [0.229, 0.224, 0.225]
  15. }
  16.  
  17.  
  18. def load_image(path):
  19. img = PIL.Image.open(path)
  20. img_raw = np.array(img, dtype=np.uint8)
  21. img = ToTensor()(img)
  22. img = Normalize(**NORMALIZATION_STATS)(img)
  23. return img, img_raw
  24.  
  25.  
  26. @click.command()
  27. @click.option('-d', '--device', default='cuda', type=str)
  28. @click.argument('path')
  29. def main(path, device):
  30. device = torch.device(device)
  31. img, img_raw = load_image(path)
  32. img = img.to(device)
  33. model = maskrcnn_resnet50_fpn(pretrained=True, rpn_nms_thresh=0.7)
  34. model = model.to(device).eval()
  35. with torch.no_grad():
  36. out = model(img.unsqueeze(0))[0]
  37. boxes, labels, masks = out['boxes'], out['labels'], out['masks']
  38.  
  39. # def _func(x):
  40. # return x
  41. #
  42. # trasformed_boxes = [_func(box) for box in boxes if True]
  43. # trasformed_boxes = map(lambda x: x ** 2, boxes)
  44. # l = []
  45. # for box in boxes:
  46. # l.append(_func(box))
  47. # boxes = l
  48. # out = dict(())
  49. img = img.cpu()
  50. out_masks = np.zeros_like(img, dtype=np.uint8)
  51.  
  52. for idx, (box, lab, mask) in enumerate(zip(boxes, labels, masks)):
  53. start_x, start_y, width, height = box
  54. p1 = (start_x, start_y)
  55. p2 = (width, height)
  56. color_hsv = {'h': idx / len(masks), 's': 1., 'v': 1.}
  57. color_rgb = tuple(map(lambda x: int(x*255), colorsys.hsv_to_rgb(**color_hsv)))
  58. img_after = cv2.rectangle(img_raw, p1, p2, color_rgb)
  59.  
  60. plt.imshow(img_after)
  61. plt.show()
  62.  
  63. # cv2.addWeighted()
  64.  
  65.  
  66. if __name__ == '__main__':
  67. main()
Advertisement
Add Comment
Please, Sign In to add comment