Advertisement
Guest User

Untitled

a guest
Jan 27th, 2020
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.34 KB | None | 0 0
  1. import os
  2. import argparse
  3. import yaml
  4. import json
  5. from utils.dataset import register_box_dataset
  6. from utils.evaluator import Evaluator, create_dirs
  7. from locator_predictor import LocatorPredictor
  8.  
  9.  
  10. class EndToEndEvaluator(Evaluator):
  11.  
  12.     ERROR_RESULT_CODES = ['ERROR', 'TIMEOUT', 'ONLY_OCCLUDED', 'EMPTY_BBOX']
  13.  
  14.     def __init__(self, dataset_dir, results_dir, pholocator_path, phoxicontrol_path, plcf_path, plcf_settings_json):
  15.         super().__init__(
  16.             results_dir,
  17.             os.path.basename(dataset_dir),
  18.             "locator_predictor",
  19.             step="final",  # "final" or string consisting of 7 decimals
  20.             config=None,  # config
  21.             weights=None,  # weights
  22.             score_threshold=0.5,
  23.             iou_type="segm",  # 'bbox' or 'segm'
  24.             iou_threshold=0.8,
  25.             low_bad_pick_iou=0.1,
  26.             high_bad_pick_iou=0.8)
  27.  
  28.         self.pholocator_path = pholocator_path
  29.         self.phoxicontrol_path = phoxicontrol_path
  30.         self.plcf_path = plcf_path
  31.         self.plcf_settings_json = plcf_settings_json
  32.         self.error_dt_json = os.path.join(self.step_dir, "error_results.json")
  33.  
  34.     def gts_by_filename(self):
  35.         result = {}
  36.         category_id = self.cocogt.getCatIds()[0]
  37.         for img_dict in self.cocogt.loadImgs(self.cocogt.getImgIds()):
  38.             img_dict["category_id"] = category_id
  39.             result[os.path.splitext(img_dict['file_name'])[0]] = img_dict
  40.         return result
  41.  
  42.     @staticmethod
  43.     def load_yaml(yaml_path):
  44.         with open(yaml_path) as f:
  45.             data = yaml.load(f, Loader=yaml.FullLoader)
  46.             return data
  47.  
  48.     @staticmethod
  49.     def get_segmentation(yaml_data, width, height):
  50.         segmentations = []
  51.         bboxes = []
  52.         if "poses" in yaml_data.keys():
  53.             for pose in yaml_data["poses"]:
  54.                 segmentation = []
  55.                 bbox = None
  56.                 for box_corner in pose["box_corners"]:
  57.                     x = box_corner['x'] * int(width)
  58.                     y = box_corner['y'] * int(height)
  59.                     segmentation.append(x)
  60.                     segmentation.append(y)
  61.                     if bbox is None:
  62.                         bbox = [x, y, x, y]
  63.                     bbox = [min(x, bbox[0]), min(y, bbox[1]), max(x, bbox[2]), max(y, bbox[3])]
  64.  
  65.                 bbox = [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]]
  66.                 segmentations.append(segmentation)
  67.                 bboxes.append(bbox)
  68.  
  69.         return bboxes, segmentations
  70.  
  71.     def create_dt_json(self, yaml_dir):
  72.         gts_by_filename = self.gts_by_filename()
  73.         yaml_files = sorted([f for f in os.listdir(yaml_dir) if os.path.splitext(f)[1] == ".yaml"])
  74.  
  75.         results_json_data = []
  76.         error_results_json_data = []
  77.  
  78.         for i, yaml_file in enumerate(yaml_files):
  79.             filename = os.path.splitext(yaml_file)[0]
  80.             yaml_data = self.load_yaml(os.path.join(yaml_dir, yaml_file))
  81.             result_code = yaml_data["result_code"]
  82.  
  83.             gt = gts_by_filename.get(filename)
  84.             if result_code == 'SUCCESS':
  85.                 if gt is None:
  86.                     print("Corresponding ground truth is missing for {} ".format(yaml_file))
  87.                     continue
  88.                 bboxes, segmentations = self.get_segmentation(yaml_data, gt['width'], gt['height'])
  89.                 for bbox, segmentation in zip(bboxes, segmentations):
  90.                     results_json_data.append({
  91.                         "image_id": gt["id"],
  92.                         "category_id": gt["category_id"],
  93.                         "score": 1.0,
  94.                         "bbox": bbox,
  95.                         "segmentation": [segmentation],
  96.                         "metadata": {}
  97.                     })
  98.             elif result_code in self.ERROR_RESULT_CODES:
  99.                 error_results_json_data.append({
  100.                     "error_result_code": result_code,
  101.                     "file_name": "{}.png".format(filename)
  102.                 })
  103.  
  104.         with open(self.dt_json, 'w') as f:
  105.             json.dump(results_json_data, f)
  106.  
  107.         with open(self.error_dt_json, 'w') as f:
  108.             json.dump(error_results_json_data, f)
  109.  
  110.     def predict_dataset(self, ):
  111.         predictor = LocatorPredictor(self.metadata.image_root, self.result_dir, self.pholocator_path, self.phoxicontrol_path, self.plcf_path, self.plcf_settings_json)
  112.         #predictor.detect()
  113.  
  114.         #TODO:: prerobit
  115.         self.create_dt_json(predictor.result_yaml_dir)
  116.  
  117.     def create_dirs(self):
  118.         super().create_dirs()
  119.         create_dirs([os.path.join(self.export_dir, code.lower()) for code in self.ERROR_RESULT_CODES])
  120.  
  121.     def debug(self):
  122.         super().debug()
  123.         with open(self.error_dt_json, 'r') as f:
  124.             error_dt = json.load(f)
  125.             for el in error_dt:
  126.                 filename = el["file_name"]
  127.                 code = el["error_result_code"]
  128.                 os.symlink(self.rel_img_source_path(filename), os.path.join(self.export_dir, code.lower(), filename))
  129.  
  130.  
  131. if __name__ == "__main__":
  132.     USER = os.environ['USER']
  133.     parser = argparse.ArgumentParser()
  134.  
  135.     # sat args
  136.     parser.add_argument("--dataset_dir", required=True,
  137.                         help="path to folder with praws, annotation and textures")
  138.  
  139.     parser.add_argument("--plcf_path", required=True,
  140.                         help="path to folder with plcfs")
  141.  
  142.     parser.add_argument("--plcf_settings_json", required=True,
  143.                         help="path to json with plcf setting, plcf's will be updated for this config")
  144.  
  145.     parser.add_argument("--results_dir", default='results/',
  146.                         help="results json and subfolders with plotted results")
  147.  
  148.     parser.add_argument("--pholocator_path", default='pholocator',
  149.                         help="name of pholocator proces")
  150.  
  151.     parser.add_argument("--phoxicontrol_path", default='PhoXiControl',
  152.                         help="name of phoxicontrol proces")
  153.  
  154.     args = parser.parse_args()
  155.  
  156.     register_box_dataset(args.dataset_dir)
  157.     evaluator = EndToEndEvaluator(args.dataset_dir, args.results_dir, args.pholocator_path, args.phoxicontrol_path, args.plcf_path, args.plcf_settings_json)
  158.     evaluator.predict_dataset()
  159.     evaluator.visualize()
  160.     evaluator.debug()
  161.     evaluator.report([0.75, 0.8, 0.85, 0.9])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement