Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>
using std::vector;
/// Writes the elements of `items` to std::cout as "[a b c]" followed by a
/// newline (std::endl, so the stream is also flushed). An empty vector prints "[]".
template <class T> void print(const vector<T> &items) {
    // Nothing precedes the first element; every later element gets one space.
    const char *separator = "";
    std::cout << '[';
    for (const auto &element : items) {
        std::cout << separator << element;
        separator = " ";
    }
    std::cout << ']' << std::endl;
}
/// Axis-aligned box with inclusive integer coordinates: (x1, y1) is the minimum
/// corner and (x2, y2) the maximum corner, and both corners belong to the box
/// (hence the "+ 1" in the size computations).
struct BoundingBox {
    int x1, y1, x2, y2;

    BoundingBox(int x1, int y1, int x2, int y2) : x1(x1), y1(y1), x2(x2), y2(y2) {}

    /// Number of integer grid points covered by the box.
    /// Width and height are computed in 64-bit arithmetic so that wide boxes do
    /// not overflow `int` before the multiplication.
    /// Assumes x1 <= x2 and y1 <= y2; otherwise the result is zero or negative.
    long long area() const {
        const long long width = static_cast<long long>(x2) - x1 + 1;
        const long long height = static_cast<long long>(y2) - y1 + 1;
        return width * height;
    }

    /// Writes "x1 y1 x2 y2" followed by std::endl to std::cout.
    void print() const { std::cout << x1 << " " << y1 << " " << x2 << " " << y2 << std::endl; }
};

/// Greedy non-maximum suppression.
///
/// Repeatedly takes the highest-scoring remaining box, keeps it, and discards
/// every other remaining box whose intersection-over-union (IoU) with it is
/// strictly greater than `iou_threshold`.
///
/// @param bboxes        Candidate boxes; must be the same length as `scores`
///                      (checked with assert). Boxes are assumed well-formed
///                      (x1 <= x2, y1 <= y2) so that the union area is positive.
/// @param scores        Score of each box; higher scores are visited first.
/// @param iou_threshold A box is kept only if its IoU with the current box is
///                      <= this value.
/// @return Indices into `bboxes` of the kept boxes, in descending score order.
///         Boxes with equal scores are visited in ascending index order.
vector<size_t> non_maximum_supression(const vector<BoundingBox> &bboxes, const vector<float> &scores,
                                      float iou_threshold = 0.5) {
    assert(bboxes.size() == scores.size());

    // Sort indices from higher score to lower. The comparator must be a strict
    // weak ordering (`>`, not `>=`), otherwise sorting is undefined behavior;
    // stable_sort additionally keeps ties in index order, making the result
    // deterministic.
    vector<size_t> order(scores.size());
    std::iota(order.begin(), order.end(), size_t{0});
    std::stable_sort(order.begin(), order.end(),
                     [&scores](size_t a, size_t b) { return scores[a] > scores[b]; });

    vector<size_t> indices_to_keep;
    vector<size_t> remaining;
    remaining.reserve(order.size());

    while (!order.empty()) {
        const size_t curr = order.front();
        indices_to_keep.push_back(curr);
        const BoundingBox &curr_bbox = bboxes[curr];
        const double curr_area = static_cast<double>(curr_bbox.area());

        remaining.clear();
        for (size_t i = 1; i < order.size(); ++i) {
            const BoundingBox &other = bboxes[order[i]];

            // Intersection rectangle (possibly empty), inclusive coordinates.
            const int ix1 = std::max(curr_bbox.x1, other.x1);
            const int iy1 = std::max(curr_bbox.y1, other.y1);
            const int ix2 = std::min(curr_bbox.x2, other.x2);
            const int iy2 = std::min(curr_bbox.y2, other.y2);
            // 64-bit so that neither the extent nor the product overflows int.
            const long long w = std::max(0LL, static_cast<long long>(ix2) - ix1 + 1);
            const long long h = std::max(0LL, static_cast<long long>(iy2) - iy1 + 1);

            const double intersection = static_cast<double>(w * h);
            const double overlap =
                intersection / (curr_area + static_cast<double>(other.area()) - intersection);

            // Boxes with high IoU are treated as duplicates of `curr` and dropped.
            if (overlap <= iou_threshold) {
                remaining.push_back(order[i]);
            }
        }
        // Reuse both buffers instead of copying the survivors back into `order`.
        order.swap(remaining);
    }
    return indices_to_keep;
}
- int main() {
- vector<BoundingBox> bboxes = {
- BoundingBox(0, 0, 10, 10),
- BoundingBox(2, 2, 12, 12),
- BoundingBox(9, 9, 19, 19),
- BoundingBox(11, 11, 21, 21),
- BoundingBox(25, 25, 35, 35)
- };
- vector<float> scores = {0.5, 0.7, 0.6, 0.3, 0.5};
- for (const auto &item: bboxes)
- item.print();
- print(scores);
- for (int i = 0; i < 100; ++i)
- std::cout << "-";
- std::cout << std::endl;
- auto res = non_maximum_supression(bboxes, scores, 0.5);
- print(res);
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement