Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- import cv2
- import torch
- import click
- import PIL.Image
- import colorsys
- import matplotlib.pyplot as plt
- from torchvision.models.detection import maskrcnn_resnet50_fpn
- from torchvision.transforms import Normalize, ToTensor
# Per-channel mean/std (the standard ImageNet values published with
# torchvision's pretrained models). Unpacked as keyword arguments into
# ``Normalize`` by ``load_image``.
NORMALIZATION_STATS = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
def load_image(path):
    """Load an image file as RGB and return it as a tensor and a pixel array.

    Args:
        path: Path to an image file readable by PIL.

    Returns:
        A tuple ``(img, img_raw)``:
        ``img`` is the float tensor produced by ``ToTensor`` and then
        normalized with ``NORMALIZATION_STATS``. ``img_raw`` is the uint8
        RGB pixel array of the same image.
    """
    # The context manager closes the file handle PIL keeps open for lazy
    # loading. ``convert`` returns an in-memory copy, so it stays usable
    # after the file is closed.
    with PIL.Image.open(path) as src:
        # Force 3 channels. Grayscale, palette or RGBA images would otherwise
        # yield 1- or 4-channel tensors, and the 3-channel Normalize below
        # would fail on them.
        pil_img = src.convert('RGB')
    img_raw = np.array(pil_img, dtype=np.uint8)
    img = ToTensor()(pil_img)
    img = Normalize(**NORMALIZATION_STATS)(img)
    return img, img_raw
@click.command()
@click.option('-d', '--device',
              default='cuda' if torch.cuda.is_available() else 'cpu',
              type=str)
@click.argument('path')
def main(path, device):
    """Run pretrained Mask R-CNN on the image at PATH and plot its boxes.

    Every predicted bounding box is drawn onto the raw RGB pixels in its own
    hue. The annotated image is then shown once with matplotlib.

    Args:
        path: Image file to analyse (CLI argument).
        device: Torch device string used for inference, e.g. ``cuda`` or
            ``cpu``. Defaults to ``cuda`` when available, else ``cpu``.
    """
    device = torch.device(device)

    # load_image's tensor is ImageNet-normalized. torchvision detection models
    # normalize internally and expect inputs in [0, 1], so build the model
    # input from the uint8 pixels instead of normalizing twice.
    _, img_raw = load_image(path)
    img = ToTensor()(img_raw).to(device)

    model = maskrcnn_resnet50_fpn(pretrained=True, rpn_nms_thresh=0.7)
    model = model.to(device).eval()
    with torch.no_grad():
        out = model(img.unsqueeze(0))[0]

    # Boxes are corner coordinates (x1, y1, x2, y2). Move them off the device
    # so they can become plain Python ints; cv2 rejects tensor coordinates.
    boxes = out['boxes'].cpu()
    num_boxes = len(boxes)
    for idx, box in enumerate(boxes):
        x1, y1, x2, y2 = (int(round(v)) for v in box.tolist())
        # Spread hues evenly over the detections so boxes are distinguishable.
        hue = idx / num_boxes
        color_rgb = tuple(int(c * 255) for c in colorsys.hsv_to_rgb(hue, 1.0, 1.0))
        # Draws in place on img_raw (RGB channel order, as matplotlib expects).
        cv2.rectangle(img_raw, (x1, y1), (x2, y2), color_rgb)

    plt.imshow(img_raw)
    plt.show()
if __name__ == "__main__":
    # Script entry point: the click command parses sys.argv and runs ``main``.
    main()
Advertisement
Add Comment
Please, Sign In to add comment