Advertisement
Guest User

Untitled

a guest
Mar 20th, 2019
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.81 KB | None | 0 0
  1. import cv2
  2. import torch
  3. import torch.nn.functional as F
  4. import scipy
  5. import numpy as np
  6.  
  7.  
  8. def draw_circle(canvas, point, radius):
  9. cv2.circle(canvas, (int(round(point[0])), int(round(point[1]))), radius, color=255, thickness=-1)
  10.  
  11.  
  12. def img_grad_col(img):
  13. img = img.unsqueeze(0).unsqueeze(0)
  14. a = torch.Tensor([[1, 0, -1],
  15. [2, 0, -2],
  16. [1, 0, -1]])
  17.  
  18. a = a.view((1, 1, 3, 3))
  19. g_x = F.conv2d(img, a)
  20. g_x = g_x.squeeze()
  21. return g_x
  22.  
  23.  
  24. def img_grad_row(img):
  25. img = img.unsqueeze(0).unsqueeze(0)
  26. b = torch.Tensor([[1, 2, 1],
  27. [0, 0, 0],
  28. [-1, -2, -1]])
  29.  
  30. b = b.view((1, 1, 3, 3))
  31. g_y = F.conv2d(img, b)
  32. g_y = g_y.squeeze()
  33. return g_y
  34.  
  35.  
  36. def theta_for_patch_center(img_shape, window_size, patch_center):
  37.  
  38. scale_x = torch.tensor(window_size[1] / img_shape[1])
  39. scale_y = torch.tensor(window_size[0] / img_shape[0])
  40.  
  41. offset_x = 2 * patch_center[1] / (img_shape[1] - 1) - 1
  42. offset_y = 2 * patch_center[0] / (img_shape[0] - 1) - 1
  43. zero_tensor = torch.tensor(0.0)
  44. theta = torch.stack([
  45. scale_x, zero_tensor, offset_x,
  46. zero_tensor, scale_y, offset_y
  47. ]).view(1, 2, 3)
  48.  
  49. return theta
  50.  
  51.  
  52. def cut_patch(img, window_size, patch_center, inter_mode="bilinear"):
  53. theta = theta_for_patch_center(img.shape, window_size, patch_center)
  54. grid = F.affine_grid(theta, [1, 1, window_size[0], window_size[1]])
  55.  
  56. img = img.unsqueeze(0).unsqueeze(0)
  57. sampled = F.grid_sample(img, grid, mode=inter_mode)
  58. sampled = sampled.squeeze()
  59. return sampled
  60.  
  61.  
  62. def compute_lk_error(frame0, frame1, window_size, patch_center, p):
  63. vals0 = cut_patch(frame0, window_size, patch_center)
  64. vals1 = cut_patch(frame1, window_size, patch_center + p)
  65.  
  66. vals0 = vals0.view(-1, 1)
  67. vals1 = vals1.view(-1, 1)
  68.  
  69. diff = vals0 - vals1
  70. return diff.transpose(0, 1) @ diff
  71.  
  72.  
  73. def compute_jacobian(frame0_x, frame0_y, window_size, patch_center):
  74. dxs = cut_patch(frame0_x, window_size, patch_center).view(-1)
  75. dys = cut_patch(frame0_y, window_size, patch_center).view(-1)
  76. return torch.stack((dxs, dys), dim=1)
  77.  
  78.  
  79. def perform_lk(frame0, frame1, window_size, patch_center, p0):
  80. p = p0.clone()
  81. print(p)
  82.  
  83. frame0_r = img_grad_row(frame0)
  84. frame0_c = img_grad_col(frame0)
  85.  
  86. jacobian = compute_jacobian(frame0_r, frame0_c, window_size, patch_center)
  87. hessian = jacobian.transpose(0, 1) @ jacobian
  88. hessian_inv = hessian.inverse()
  89.  
  90. n_iters = 200
  91. for iter_ind in range(n_iters):
  92. err = compute_lk_error(frame0, frame1, window_size, patch_center, p)
  93. print(f"{iter_ind + 1}/{n_iters} Loss = {err}")
  94. if err < 1e-3:
  95. break
  96.  
  97. vals0 = cut_patch(frame0, window_size, patch_center).view(-1, 1)
  98. vals1 = cut_patch(frame1, window_size, patch_center + p).view(-1, 1)
  99. grad = jacobian.transpose(0, 1) @ (vals0 - vals1)
  100. dp = hessian_inv @ grad
  101. p -= dp.view(2)
  102. print(p)
  103. return p
  104.  
  105.  
  106. def main():
  107. canvas1 = np.zeros((128, 128), dtype=np.uint8)
  108. canvas2 = np.zeros((128, 128), dtype=np.uint8)
  109.  
  110. # canvas1 = np.zeros((256, 256), dtype=np.uint8)
  111. # canvas2 = np.zeros((256, 256), dtype=np.uint8)
  112.  
  113. draw_circle(canvas1, (64, 64), 5)
  114. draw_circle(canvas2, (66, 64), 5)
  115.  
  116. canvas1_torch = torch.FloatTensor(canvas1)
  117. canvas2_torch = torch.FloatTensor(canvas2)
  118.  
  119. window_size = (11, 11)
  120. patch_center = torch.FloatTensor([64, 64])
  121. patch_center.requires_grad_(True)
  122. p = torch.FloatTensor([-1, 0])
  123.  
  124. print(compute_lk_error(canvas1_torch, canvas2_torch, window_size, patch_center, p))
  125. print(perform_lk(canvas1_torch, canvas2_torch, window_size, patch_center, p))
  126.  
  127. # cv2.imshow('', canvas1)
  128. # cv2.imshow('r', canvas1_r)
  129. # cv2.imshow('c', canvas1_c)
  130. # cv2.waitKey()
  131.  
  132.  
  133. if __name__ == '__main__':
  134. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement