Advertisement
Guest User

Untitled

a guest
Jul 18th, 2019
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.89 KB | None | 0 0
  1. # --------------------------------------------------------
  2. # SiamMask
  3. # Licensed under The MIT License
  4. # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
  5. # --------------------------------------------------------
  6. import glob
  7. from tools.test import *
  8.  
  9. parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
  10.  
  11. parser.add_argument('--resume', default='', type=str, required=True,
  12. metavar='PATH',help='path to latest checkpoint (default: none)')
  13. parser.add_argument('--config', dest='config', default='config_davis.json',
  14. help='hyper-parameter of SiamMask in json format')
  15. parser.add_argument('--base_path', default='../../data/tennis', help='datasets')
  16. parser.add_argument('--cpu', action='store_true', help='cpu mode')
  17. args = parser.parse_args()
  18.  
  19. if __name__ == '__main__':
  20. # Setup device
  21. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  22.  
  23. if torch.cuda.is_available():
  24. print("Running on GPU")
  25. else:
  26. print("Running on CPU :(")
  27.  
  28. torch.backends.cudnn.benchmark = True
  29.  
  30. # Setup Model
  31. cfg = load_config(args)
  32. from custom import Custom
  33. siammask = Custom(anchors=cfg['anchors'])
  34. if args.resume:
  35. assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
  36. siammask = load_pretrain(siammask, args.resume)
  37.  
  38. siammask.eval().to(device)
  39.  
  40. # Parse Image file
  41. img_files = sorted(glob.glob(join("/home/cin/Projects/SiamMask/data/soccer", '*.jpg')))
  42. ims = [cv2.imread(imf) for imf in img_files]
  43.  
  44. # Select ROI
  45. cv2.namedWindow("SiamMask", cv2.WND_PROP_FULLSCREEN)
  46. # cv2.setWindowProperty("SiamMask", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
  47. try:
  48. init_rect = cv2.selectROI('SiamMask', ims[0], False, False)
  49. x, y, w, h = init_rect
  50. except:
  51. exit()
  52.  
  53. toc = 0
  54. for f, im in enumerate(ims):
  55. tic = cv2.getTickCount()
  56. if f == 0: # init
  57. target_pos = np.array([x + w / 2, y + h / 2])
  58. target_sz = np.array([w, h])
  59. state = siamese_init(im, target_pos, target_sz, siammask, cfg['hp'], device=device) # init tracker
  60. elif f > 0: # tracking
  61. state = siamese_track(state, im, mask_enable=True, refine_enable=True, device=device) # track
  62. location = state['ploygon'].flatten()
  63. mask = state['mask'] > state['p'].seg_thr
  64.  
  65. im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
  66. cv2.polylines(im, [np.int0(location).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
  67. cv2.imshow('SiamMask', im)
  68. key = cv2.waitKey(1)
  69. if key > 0:
  70. break
  71.  
  72. toc += cv2.getTickCount() - tic
  73. toc /= cv2.getTickFrequency()
  74. fps = f / toc
  75. print('SiamMask Time: {:02.1f}s Speed: {:3.1f}fps (with visulization!)'.format(toc, fps))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement