Guest User

Untitled

a guest
Nov 16th, 2018
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.38 KB | None | 0 0
  1. import numpy as np
  2.  
  3.  
  4.  
  class PoseMatching:
      """Score a predicted 2-D pose against a ground-truth pose using PCK.

      A joint counts as correct when the Euclidean distance between its
      predicted and ground-truth positions is at most ``pck_ratio_thresh``
      times the longer side of a reference rectangle.

      Args:
          num_joints: Number of joints in every pose.
          pck_ratio_thresh: Tolerance as a fraction (0~1) of the reference
              rectangle's longer side. Lower values demand closer matches
              before a joint is accepted as correct.
      """

      def __init__(self, num_joints=16, pck_ratio_thresh=0.1):
          self.num_joints = num_joints
          self.pck_ratio_thresh = pck_ratio_thresh

      def _enclosing_rect(self, points, present=None):
          """Return the axis-aligned bounding box of ``points``.

          Args:
              points: Flat ``[x0, y0, x1, y1, ...]`` or an ``(N, 2)``
                  array-like of ``(x, y)`` pairs.
              present: Optional per-point flags; when given, only points with
                  a non-zero flag contribute to the box.

          Returns:
              ``np.array([x_min, y_min, x_max, y_max])`` with the input dtype.

          Raises:
              ValueError: If no points remain, or ``points`` does not hold
                  whole ``(x, y)`` pairs.
          """
          xy = np.asarray(points).reshape(-1, 2)
          if present is not None:
              xy = xy[np.asarray(present).reshape(-1) != 0]
          return np.concatenate([xy.min(axis=0), xy.max(axis=0)])

      def rect_size(self, rect):
          """Return ``np.array([width, height])`` of an ``[x1, y1, x2, y2]`` rect."""
          return np.array([rect[2] - rect[0], rect[3] - rect[1]])

      def _eval_pck(self, prediction_joint, groundtruth_joint, gt_present_joints, gt_rect):
          """Compute the PCK of one predicted pose against its ground truth.

          Args:
              prediction_joint: Predicted coordinates, flat
                  ``[x0, y0, x1, y1, ...]`` or ``(num_joints, 2)``.
              groundtruth_joint: Ground-truth coordinates in the same layout.
              gt_present_joints: One flag per joint; a non-zero flag means the
                  joint exists in the ground truth and is scored, zero means
                  it is ignored.
              gt_rect: Reference rectangle ``[x1, y1, x2, y2]``; its longer
                  side scales the distance threshold.

          Returns:
              PCK in 0~100 (100 means every present joint is within the
              threshold), or NaN when no joint is marked present.

          Raises:
              ValueError: If ``groundtruth_joint`` is empty, or a pose or the
                  flag list does not contain exactly ``num_joints`` entries.
          """
          if np.size(groundtruth_joint) == 0:
              raise ValueError("gt joint is 0 .")

          gt_xy = np.asarray(groundtruth_joint, dtype=float).reshape(self.num_joints, 2)
          pred_xy = np.asarray(prediction_joint, dtype=float).reshape(self.num_joints, 2)
          # The same presence mask is used for numerator and denominator, so
          # both count exactly the same joints.
          present = np.asarray(gt_present_joints).reshape(self.num_joints) != 0

          # The tolerance is relative to the reference rect's longer side.
          pck_thresh = self.pck_ratio_thresh * np.amax(self.rect_size(gt_rect))

          dists = np.sqrt(np.sum((pred_xy - gt_xy) ** 2, axis=1))
          correct = dists <= pck_thresh

          num_present = np.count_nonzero(present)
          if num_present == 0:
              # Nothing to score: return NaN explicitly instead of dividing by zero.
              return np.float64(np.nan)
          num_correct = np.count_nonzero(correct & present)
          return np.float64(num_correct) / num_present * 100.0
  60.  
  61.  
  if __name__ == "__main__":
      # Matcher for a 3-joint skeleton; tolerance is 0.1 x the box's longer side.
      matcher = PoseMatching(num_joints=3, pck_ratio_thresh=0.1)

      # Template setup (done once): flat ground-truth coordinates
      # [x0, y0, x1, y1, ...], one presence flag per joint, and their bounding box.
      template_joints = [11, 21, 13, 26, 45, 50]
      template_present = [1, 0, 1]
      template_box = matcher._enclosing_rect(template_joints)

      # For each prediction, measure how well it matches the template.
      predicted_joints = [10, 20, 13, 25, 40, 45]
      score = matcher._eval_pck(
          predicted_joints, template_joints, template_present, template_box
      )
      print(score)
Add Comment
Please, Sign In to add comment