Guest User

Untitled

a guest
Feb 15th, 2019
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 16.39 KB | None | 0 0
  1. import os
  2. import sys
  3. import h5py
  4. import numpy as np
  5. from progressbar import ProgressBar
  6. from commons import check_mkdir
  7.  
  8. def load_gt_h5(fn):
  9. """ Output: pts B x N x 3 float32
  10. gt_mask B x K x N bool
  11. gt_mask_label B x K uint8
  12. gt_mask_valid B x K bool
  13. gt_mask_other B x N bool
  14. All the ground-truth masks are represented as 0/1 mask over the 10k point cloud.
  15. All the ground-truth masks are disjoint, complete and corresponding to unique semantic labels.
  16. Different test shapes have different numbers of ground-truth masks, they are at the top in array gt_mask indicated by gt_valid.
  17. """
  18. with h5py.File(fn, 'r') as fin:
  19. gt_mask = fin['gt_mask'][:]
  20. gt_mask_label = fin['gt_mask_label'][:]
  21. gt_mask_valid = fin['gt_mask_valid'][:]
  22. gt_mask_other = fin['gt_mask_other'][:]
  23. return gt_mask, gt_mask_label, gt_mask_valid, gt_mask_other
  24.  
  25. def load_pred_h5(fn):
  26. """ Output: mask B x K x N bool
  27. label B x K uint8
  28. valid B x K bool
  29. conf B x K float32
  30. We only evaluate on the part predictions with valid = True.
  31. We assume no pre-sorting according to confidence score.
  32. """
  33. with h5py.File(fn, 'r') as fin:
  34. mask = fin['mask'][:]
  35. label = fin['label'][:]
  36. valid = fin['valid'][:]
  37. conf = fin['conf'][:]
  38. return mask, label, valid, conf
  39.  
  40. def compute_ap(tp, fp, gt_npos, n_bins=100, plot_fn=None):
  41. assert len(tp) == len(fp), 'ERROR: the length of true_pos and false_pos is not the same!'
  42.  
  43. tp = np.cumsum(tp)
  44. fp = np.cumsum(fp)
  45.  
  46. rec = tp / gt_npos
  47. prec = tp / (fp + tp)
  48.  
  49. rec = np.insert(rec, 0, 0.0)
  50. prec = np.insert(prec, 0, 1.0)
  51.  
  52. ap = 0.
  53. delta = 1.0 / n_bins
  54.  
  55. out_rec = np.arange(0, 1 + delta, delta)
  56. out_prec = np.zeros((n_bins+1), dtype=np.float32)
  57.  
  58. for idx, t in enumerate(out_rec):
  59. prec1 = prec[rec >= t]
  60. if len(prec1) == 0:
  61. p = 0.
  62. else:
  63. p = max(prec1)
  64.  
  65. out_prec[idx] = p
  66. ap = ap + p / (n_bins + 1)
  67.  
  68. if plot_fn is not None:
  69. import matplotlib.pyplot as plt
  70. fig = plt.figure()
  71. plt.plot(out_rec, out_prec, 'b-')
  72. plt.title('PR-Curve (AP: %4.2f%%)' % (ap*100))
  73. plt.xlabel('Recall')
  74. plt.ylabel('Precision')
  75. plt.xlim([0, 1])
  76. plt.ylim([0, 1])
  77. fig.savefig(plot_fn)
  78. plt.close(fig)
  79.  
  80. return ap
  81.  
  82. def eval_per_class_ap(stat_fn, gt_dir, pred_dir, iou_threshold=0.5, plot_dir=None):
  83. """ Input: stat_fn contains all part ids and names
  84. gt_dir contains test-xx.h5
  85. pred_dir contains test-xx.h5
  86. Output: aps: Average Prediction Scores for each part category, evaluated on all test shapes
  87. mAP: mean AP
  88. """
  89. print('Evaluation Start.')
  90. print('Ground-truth Directory: %s' % gt_dir)
  91. print('Prediction Directory: %s' % pred_dir)
  92.  
  93. if plot_dir is not None:
  94. check_mkdir(plot_dir)
  95.  
  96. # read stat_fn
  97. with open(stat_fn, 'r') as fin:
  98. part_name_list = [item.rstrip().split()[1] for item in fin.readlines()]
  99. print('Part Name List: ', part_name_list)
  100. n_labels = len(part_name_list)
  101. print('Total Number of Semantic Labels: %d' % n_labels)
  102.  
  103. # check all h5 files
  104. test_h5_list = []
  105. for item in os.listdir(gt_dir):
  106. if item.startswith('test-') and item.endswith('.h5'):
  107. if not os.path.exists(os.path.join(pred_dir, item)):
  108. print('ERROR: h5 file %s is in gt directory but not in pred directory.')
  109. exit(1)
  110. test_h5_list.append(item)
  111.  
  112. # read each h5 file and collect per-part-category true_pos, false_pos and confidence scores
  113. true_pos_list = [[] for item in part_name_list]
  114. false_pos_list = [[] for item in part_name_list]
  115. conf_score_list = [[] for item in part_name_list]
  116.  
  117. gt_npos = np.zeros((n_labels), dtype=np.int32)
  118.  
  119. for item in test_h5_list:
  120. print('Testing %s' % item)
  121.  
  122. gt_mask, gt_mask_label, gt_mask_valid, gt_mask_other = load_gt_h5(os.path.join(gt_dir, item))
  123. pred_mask, pred_label, pred_valid, pred_conf = load_pred_h5(os.path.join(pred_dir, item))
  124.  
  125. n_shape = gt_mask.shape[0]
  126. gt_n_ins = gt_mask.shape[1]
  127. pred_n_ins = pred_mask.shape[1]
  128.  
  129. for i in range(n_shape):
  130. cur_pred_mask = pred_mask[i, ...]
  131. cur_pred_label = pred_label[i, :]
  132. cur_pred_conf = pred_conf[i, :]
  133. cur_pred_valid = pred_valid[i, :]
  134.  
  135. cur_gt_mask = gt_mask[i, ...]
  136. cur_gt_label = gt_mask_label[i, :]
  137. cur_gt_valid = gt_mask_valid[i, :]
  138. cur_gt_other = gt_mask_other[i, :]
  139.  
  140. # classify all valid gt masks by part categories
  141. gt_mask_per_cat = [[] for item in part_name_list]
  142. for j in range(gt_n_ins):
  143. if cur_gt_valid[j]:
  144. sem_id = cur_gt_label[j]
  145. gt_mask_per_cat[sem_id].append(j)
  146. gt_npos[sem_id] += 1
  147.  
  148. # sort prediction and match iou to gt masks
  149. cur_pred_conf[~cur_pred_valid] = 0.0
  150. order = np.argsort(-cur_pred_conf)
  151.  
  152. gt_used = np.zeros((gt_n_ins), dtype=np.bool)
  153.  
  154. for j in range(pred_n_ins):
  155. idx = order[j]
  156. if cur_pred_valid[idx]:
  157. sem_id = cur_pred_label[idx]
  158.  
  159. iou_max = 0.0; cor_gt_id = -1;
  160. for k in gt_mask_per_cat[sem_id]:
  161. if not gt_used[k]:
  162. # Remove points with gt label *other* from the prediction
  163. # We will not evaluate them in the IoU since they can be assigned any label
  164. clean_cur_pred_mask = (cur_pred_mask[idx, :] & (~cur_gt_other))
  165.  
  166. intersect = np.sum(cur_gt_mask[k, :] & clean_cur_pred_mask)
  167. union = np.sum(cur_gt_mask[k, :] | clean_cur_pred_mask)
  168. iou = intersect * 1.0 / union
  169.  
  170. if iou > iou_max:
  171. iou_max = iou
  172. cor_gt_id = k
  173.  
  174. if iou_max > iou_threshold:
  175. gt_used[cor_gt_id] = True
  176.  
  177. # add in a true positive
  178. true_pos_list[sem_id].append(True)
  179. false_pos_list[sem_id].append(False)
  180. conf_score_list[sem_id].append(cur_pred_conf[idx])
  181. else:
  182. # add in a false positive
  183. true_pos_list[sem_id].append(False)
  184. false_pos_list[sem_id].append(True)
  185. conf_score_list[sem_id].append(cur_pred_conf[idx])
  186.  
  187. # compute per-part-category AP
  188. aps = np.zeros((n_labels), dtype=np.float32)
  189. ap_valids = np.ones((n_labels), dtype=np.bool)
  190. for i in range(n_labels):
  191. has_pred = (len(true_pos_list[i]) > 0)
  192. has_gt = (gt_npos[i] > 0)
  193.  
  194. if not has_gt:
  195. ap_valids[i] = False
  196. continue
  197.  
  198. if has_gt and not has_pred:
  199. continue
  200.  
  201. cur_true_pos = np.array(true_pos_list[i], dtype=np.float32)
  202. cur_false_pos = np.array(false_pos_list[i], dtype=np.float32)
  203. cur_conf_score = np.array(conf_score_list[i], dtype=np.float32)
  204.  
  205. # sort according to confidence score again
  206. order = np.argsort(-cur_conf_score)
  207. sorted_true_pos = cur_true_pos[order]
  208. sorted_false_pos = cur_false_pos[order]
  209.  
  210. out_plot_fn = None
  211. if plot_dir is not None:
  212. out_plot_fn = os.path.join(plot_dir, part_name_list[i].replace('/', '-')+'.png')
  213.  
  214. aps[i] = compute_ap(sorted_true_pos, sorted_false_pos, gt_npos[i], plot_fn=out_plot_fn)
  215.  
  216. # compute mean AP
  217. mean_ap = np.sum(aps * ap_valids) / np.sum(ap_valids)
  218.  
  219. return aps, ap_valids, gt_npos, mean_ap
  220.  
  221. def eval_per_shape_mean_ap(stat_fn, gt_dir, pred_dir, iou_threshold=0.5):
  222. """ Input: stat_fn contains all part ids and names
  223. gt_dir contains test-xx.h5
  224. pred_dir contains test-xx.h5
  225. Output: mean_aps: per-shape mean aps, which is the mean AP on each test shape,
  226. for each shape, we only consider the parts that exist in either gt or pred
  227. shape_valids: If a shape has valid parts to evaluate or not
  228. mean_mean_ap: mean per-shape mean aps
  229. """
  230. print('Evaluation Start.')
  231. print('Ground-truth Directory: %s' % gt_dir)
  232. print('Prediction Directory: %s' % pred_dir)
  233.  
  234. # read stat_fn
  235. with open(stat_fn, 'r') as fin:
  236. part_name_list = [item.rstrip().split()[1] for item in fin.readlines()]
  237. print('Part Name List: ', part_name_list)
  238. n_labels = len(part_name_list)
  239. print('Total Number of Semantic Labels: %d' % n_labels)
  240.  
  241. # check all h5 files
  242. test_h5_list = []
  243. for item in os.listdir(gt_dir):
  244. if item.startswith('test-') and item.endswith('.h5'):
  245. if not os.path.exists(os.path.join(pred_dir, item)):
  246. print('ERROR: h5 file %s is in gt directory but not in pred directory.')
  247. exit(1)
  248. test_h5_list.append(item)
  249.  
  250. mean_aps = []
  251. shape_valids = []
  252.  
  253. # read each h5 file
  254. for item in test_h5_list:
  255. print('Testing %s' % item)
  256.  
  257. gt_mask, gt_mask_label, gt_mask_valid, gt_mask_other = load_gt_h5(os.path.join(gt_dir, item))
  258. pred_mask, pred_label, pred_valid, pred_conf = load_pred_h5(os.path.join(pred_dir, item))
  259.  
  260. n_shape = gt_mask.shape[0]
  261. gt_n_ins = gt_mask.shape[1]
  262. pred_n_ins = pred_mask.shape[1]
  263.  
  264. for i in range(n_shape):
  265. cur_pred_mask = pred_mask[i, ...]
  266. cur_pred_label = pred_label[i, :]
  267. cur_pred_conf = pred_conf[i, :]
  268. cur_pred_valid = pred_valid[i, :]
  269.  
  270. cur_gt_mask = gt_mask[i, ...]
  271. cur_gt_label = gt_mask_label[i, :]
  272. cur_gt_valid = gt_mask_valid[i, :]
  273. cur_gt_other = gt_mask_other[i, :]
  274.  
  275. # per-shape evaluation
  276. true_pos_list = [[] for item in part_name_list]
  277. false_pos_list = [[] for item in part_name_list]
  278. gt_npos = np.zeros((n_labels), dtype=np.int32)
  279.  
  280. # classify all valid gt masks by part categories
  281. gt_mask_per_cat = [[] for item in part_name_list]
  282. for j in range(gt_n_ins):
  283. if cur_gt_valid[j]:
  284. sem_id = cur_gt_label[j]
  285. gt_mask_per_cat[sem_id].append(j)
  286. gt_npos[sem_id] += 1
  287.  
  288. # sort prediction and match iou to gt masks
  289. cur_pred_conf[~cur_pred_valid] = 0.0
  290. order = np.argsort(-cur_pred_conf)
  291.  
  292. gt_used = np.zeros((gt_n_ins), dtype=np.bool)
  293.  
  294. # enumerate all pred parts
  295. for j in range(pred_n_ins):
  296. idx = order[j]
  297. if cur_pred_valid[idx]:
  298. sem_id = cur_pred_label[idx]
  299.  
  300. iou_max = 0.0; cor_gt_id = -1;
  301. for k in gt_mask_per_cat[sem_id]:
  302. if not gt_used[k]:
  303. # Remove points with gt label *other* from the prediction
  304. # We will not evaluate them in the IoU since they can be assigned any label
  305. clean_cur_pred_mask = (cur_pred_mask[idx, :] & (~cur_gt_other))
  306.  
  307. intersect = np.sum(cur_gt_mask[k, :] & clean_cur_pred_mask)
  308. union = np.sum(cur_gt_mask[k, :] | clean_cur_pred_mask)
  309. iou = intersect * 1.0 / union
  310.  
  311. if iou > iou_max:
  312. iou_max = iou
  313. cor_gt_id = k
  314.  
  315. if iou_max > iou_threshold:
  316. gt_used[cor_gt_id] = True
  317.  
  318. # add in a true positive
  319. true_pos_list[sem_id].append(True)
  320. false_pos_list[sem_id].append(False)
  321. else:
  322. # add in a false positive
  323. true_pos_list[sem_id].append(False)
  324. false_pos_list[sem_id].append(True)
  325.  
  326. # evaluate per-part-category AP for the shape
  327. aps = np.zeros((n_labels), dtype=np.float32)
  328. ap_valids = np.zeros((n_labels), dtype=np.bool)
  329. for j in range(n_labels):
  330. has_pred = (len(true_pos_list[j]) > 0)
  331. has_gt = (gt_npos[j] > 0)
  332.  
  333. if has_pred and has_gt:
  334. cur_true_pos = np.array(true_pos_list[j], dtype=np.float32)
  335. cur_false_pos = np.array(false_pos_list[j], dtype=np.float32)
  336. aps[j] = compute_ap(cur_true_pos, cur_false_pos, gt_npos[j])
  337. ap_valids[j] = True
  338. elif has_pred and not has_gt:
  339. aps[j] = 0.0
  340. ap_valids[j] = True
  341. elif not has_pred and has_gt:
  342. aps[j] = 0.0
  343. ap_valids[j] = True
  344.  
  345. # compute mean AP for the current shape
  346. if np.sum(ap_valids) > 0:
  347. mean_aps.append(np.sum(aps * ap_valids) / np.sum(ap_valids))
  348. shape_valids.append(True)
  349. else:
  350. mean_aps.append(0.0)
  351. shape_valids.append(False)
  352.  
  353. # compute mean mean AP
  354. mean_aps = np.array(mean_aps, dtype=np.float32)
  355. shape_valids = np.array(shape_valids, dtype=np.bool)
  356.  
  357. mean_mean_ap = np.sum(mean_aps * shape_valids) / np.sum(shape_valids)
  358.  
  359. return mean_aps, shape_valids, mean_mean_ap
  360.  
  361.  
  362. def eval_hier_mean_iou(gt_labels, pred_labels, tree_constraint, tree_parents):
  363. """
  364. Input:
  365. gt_labels B x N x (C+1), boolean
  366. pred_logits B x N x (C+1), boolean
  367. tree_constraint T x (C+1), boolean
  368. tree_parents T, int32
  369. Output:
  370. mean_iou Scalar, float32
  371. part_iou C, float32
  372. """
  373. assert gt_labels.shape[0] == pred_labels.shape[0], 'ERROR: gt and pred have different num_shape'
  374. assert gt_labels.shape[1] == pred_labels.shape[1], 'ERROR: gt and pred have different num_point'
  375. assert gt_labels.shape[2] == pred_labels.shape[2], 'ERROR: gt and pred have different num_class+1'
  376.  
  377. num_shape = gt_labels.shape[0]
  378. num_point = gt_labels.shape[1]
  379. num_class = gt_labels.shape[2] - 1
  380.  
  381. assert tree_constraint.shape[0] == tree_parents.shape[0], 'ERROR: tree_constraint and tree_parents have different num_constraint'
  382. assert tree_constraint.shape[1] == num_class + 1 , 'ERROR: tree_constraint.shape[1] != num_class + 1'
  383. assert len(tree_parents.shape) == 1, 'ERROR: tree_parents is not a 1-dim array'
  384.  
  385. # make a copy of the prediction
  386. pred_labels = np.array(pred_labels, dtype=np.bool)
  387.  
  388. num_constraint = tree_constraint.shape[0]
  389.  
  390. part_intersect = np.zeros((num_class+1), dtype=np.float32)
  391. part_union = np.zeros((num_class+1), dtype=np.float32)
  392.  
  393. part_intersect[1] = np.sum(pred_labels[:, :, 1] & gt_labels[:, :, 1])
  394. part_union[1] = np.sum(pred_labels[:, :, 1] | gt_labels[:, :, 1])
  395.  
  396. all_idx = np.arange(num_class+1)
  397. all_visited = np.zeros((num_class+1), dtype=np.bool)
  398.  
  399. all_visited[1] = True
  400. for i in range(num_constraint):
  401. cur_pid = tree_parents[i]
  402. if all_visited[cur_pid]:
  403. cur_cons = tree_constraint[i]
  404.  
  405. idx = all_idx[cur_cons]
  406. gt_other = ((np.sum(gt_labels[:, :, idx], axis=-1) == 0) & gt_labels[:, :, cur_pid])
  407.  
  408. for j in list(idx):
  409. pred_labels[:, :, j] = (pred_labels[:, :, cur_pid] & pred_labels[:, :, j] & (~gt_other))
  410. part_intersect[j] += np.sum(pred_labels[:, :, j] & gt_labels[:, :, j])
  411. part_union[j] += np.sum(pred_labels[:, :, j] | gt_labels[:, :, j])
  412. all_visited[j] = True
  413.  
  414. all_valid_part_ids = all_idx[all_visited]
  415. part_iou = np.divide(part_intersect[all_valid_part_ids], part_union[all_valid_part_ids])
  416. mean_iou = np.mean(part_iou)
  417.  
  418. return mean_iou, part_iou, part_intersect[all_valid_part_ids], part_union[all_valid_part_ids]
Add Comment
Please, Sign In to add comment