Ditwoo

dummy nms

May 15th, 2021
484
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

using std::vector;

  7. template <class T> void print(const vector<T> &items) {
  8.     std::cout << "[";
  9.     for (size_t i = 0; i < items.size(); ++i) {
  10.         std::cout << items[i];
  11.         if (i != items.size() - 1)
  12.             std::cout << " ";
  13.     }
  14.     std::cout << "]" << std::endl;
  15. }
  16.  
  17. struct BoundingBox {
  18.     int x1, y1, x2, y2;
  19.  
  20.     BoundingBox(int x1, int y1, int x2, int y2): x1(x1), y1(y1), x2(x2), y2(y2) {};
  21.     long long area() const { return ((long long) (this->x2 - this->x1 + 1)) * (this->y2 - this->y1 + 1); };
  22.     void print() const { std::cout << this->x1 << " " << this->y1 << " " << this->x2 << " " << this->y2 << std::endl; };
  23. };
  24.  
  25. vector<size_t> non_maximum_supression(const vector<BoundingBox> &bboxes, const vector<float> &scores, float iou_threshold = 0.5) {
  26.     assert(bboxes.size() == scores.size());
  27.  
  28.     vector<size_t> order(scores.size()), new_order;
  29.     for (size_t i = 0; i < scores.size(); ++i)
  30.         order[i] = i;
  31.  
  32.     // sort indices from higher score to lover
  33.     std::sort(order.begin(), order.end(), [&](int i, int j){return scores[i] >= scores[j];});
  34.  
  35.     vector<size_t> indices_to_keep = {};
  36.     size_t curr;
  37.     int x1, y1, x2, y2, w, h;
  38.     float intersection, overlap;
  39.     const BoundingBox *curr_bbox, *ith_bbox;
  40.  
  41.     while (!order.empty()) {
  42.         curr = order[0];
  43.         indices_to_keep.push_back(curr);
  44.         curr_bbox = &bboxes[curr];
  45.  
  46.         new_order.clear();
  47.         for (size_t i = 1; i < order.size(); ++i){
  48.             ith_bbox = &bboxes[order[i]];
  49.  
  50.             x1 = std::max(curr_bbox->x1, ith_bbox->x1);
  51.             y1 = std::max(curr_bbox->y1, ith_bbox->y1);
  52.             x2 = std::min(curr_bbox->x2, ith_bbox->x2);
  53.             y2 = std::min(curr_bbox->y2, ith_bbox->y2);
  54.  
  55.             w = std::max(0, x2 - x1 + 1);
  56.             h = std::max(0, y2 - y1 + 1);
  57.             intersection = float(w * h);
  58.             overlap = intersection / (float(curr_bbox->area()) + float(ith_bbox->area()) - intersection);
  59.  
  60.             // ignore bounding boxes with high iou (they are marked as similar)
  61.             if (overlap <= iou_threshold) {
  62.                 new_order.push_back(order[i]);
  63.             }
  64.         }
  65.         order.clear();
  66.         order = new_order;
  67.     }
  68.  
  69.     return indices_to_keep;
  70. }
  71.  
  72.  
  73. int main() {
  74.     vector<BoundingBox> bboxes = {
  75.             BoundingBox(0, 0, 10, 10),
  76.             BoundingBox(2, 2, 12, 12),
  77.             BoundingBox(9, 9, 19, 19),
  78.             BoundingBox(11, 11, 21, 21),
  79.             BoundingBox(25, 25, 35, 35)
  80.     };
  81.     vector<float> scores = {0.5, 0.7, 0.6, 0.3, 0.5};
  82.     for (const auto &item: bboxes)
  83.         item.print();
  84.     print(scores);
  85.     for (int i = 0; i < 100; ++i)
  86.         std::cout << "-";
  87.     std::cout << std::endl;
  88.  
  89.     auto res = non_maximum_supression(bboxes, scores, 0.5);
  90.     print(res);
  91.     return 0;
  92. }
  93.  
RAW Paste Data