Advertisement
Guest User

Untitled

a guest
Apr 24th, 2018
71
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 10.19 KB | None | 0 0
  1. #pragma once
  2.  
  3. #include <iostream>
  4. #include <cassert>
  5. #include <array>
  6. #include <vector>
  7. #include <cstdlib>
  8. #include <random>
  9. #include <algorithm>
  10. #include <chrono>
  11. #include <queue>
  12. #include <cmath>
  13. #include <thread>
  14. #include <fstream>
  15. #include <string>
  16. #include <limits>
  17.  
  18. #include "bounding_box.h"
  19. #include "timer.h"
  20. #include "random.h"
  21. #include "utils.h"
  22.  
  23. template<
  24.     class Point,
  25.     int ndim,
  26.     float (*GetKey)(Point &p, int axis),
  27.     float (*CalcDistance)(Point &a, Point &b)>
  28. class KDTree {
  29.     Point *points;
  30.     int length;
  31.     int bucketSize;
  32.     int threadDepth;
  33.  
  34. public:
  35.  
  36.     KDTree(Point *points, int length, int bucketSize) {
  37.         assert(ndim >= 1);
  38.         assert(length >= 0);
  39.         assert(bucketSize >= 1);
  40.  
  41.         this->points = points;
  42.         this->length = length;
  43.         this->bucketSize = bucketSize;
  44.         this->threadDepth = (int)std::log2(getCpuCoreCount()) + 1;
  45.     }
  46.  
  47.     void balance() {
  48.         TIMEIT("balance")
  49.         sort(0, length - 1, 0);
  50.     }
  51.  
  52.     std::vector<Point> collectInRadius(Point &origin, float radius) {
  53.         InRadiusCollector collector(origin, radius);
  54.         collectInRadius(0, length - 1, 0, collector);
  55.         return collector.getPoints();;
  56.     }
  57.  
  58.     std::vector<Point> collectKNearest(Point &origin, int k) {
  59.         KNearestCollector collector(origin, k);
  60.         collectKNearest(0, length - 1, 0, collector);
  61.         return collector.getPoints();
  62.     }
  63.  
  64.     std::vector<Point> collectInRadius_Naive(Point &origin, float radius) {
  65.         return collectNaive(InRadiusCollector(origin, radius));
  66.     }
  67.  
  68.     std::vector<Point> collectKNearest_Naive(Point &origin, int k) {
  69.         return collectNaive(KNearestCollector(origin, k));
  70.     }
  71.  
  72.     struct BoundingBoxWithDepth {
  73.         BoundingBox<ndim> box;
  74.         int depth;
  75.  
  76.         BoundingBoxWithDepth(BoundingBox<ndim> box, int depth)
  77.             : box(box), depth(depth) {}
  78.     };
  79.  
  80.     std::vector<BoundingBoxWithDepth> getBoundingBoxes(BoundingBox<ndim> outerBox) {
  81.         TIMEIT("getBoundingBoxes")
  82.         std::vector<BoundingBoxWithDepth> boxes;
  83.         insertBoundingBoxes(0, length - 1, 0, outerBox, boxes);
  84.         return boxes;
  85.     }
  86.  
  87.     void insertBoundingBoxes(int left, int right, int depth, BoundingBox<ndim> outerBox, std::vector<BoundingBoxWithDepth> &boxes) {
  88.         if (isLeaf(left, right)) return;
  89.  
  90.         int axis = getAxis(depth);
  91.         int medianIndex = getMedianIndex(left, right);
  92.         float splitPos = getValue(medianIndex, axis);
  93.  
  94.         boxes.push_back(BoundingBoxWithDepth(outerBox, depth));
  95.  
  96.         BoundingBox<ndim> leftBox = outerBox;
  97.         leftBox.max[axis] = splitPos;
  98.         insertBoundingBoxes(left, medianIndex - 1, depth + 1, leftBox, boxes);
  99.  
  100.         BoundingBox<ndim> rightBox = outerBox;
  101.         rightBox.min[axis] = splitPos;
  102.         insertBoundingBoxes(medianIndex + 1, right, depth + 1, rightBox, boxes);
  103.     }
  104.  
  105. private:
  106.  
  107.     /* Sort Point Array
  108.     *****************************************************/
  109.  
  110.     void sort(int left, int right, int depth) {
  111.         if (isLeaf(left, right)) return;
  112.  
  113.         int axis = getAxis(depth);
  114.         int medianIndex = fixateMedian(left, right, axis);
  115.  
  116.         if (depth < threadDepth) {
  117.             std::thread thread(&KDTree::sort, this, left, medianIndex - 1, depth + 1);
  118.             sort(medianIndex + 1, right, depth + 1);
  119.             thread.join();
  120.         } else {
  121.             sort(left, medianIndex - 1, depth + 1);
  122.             sort(medianIndex + 1, right, depth + 1);
  123.         }
  124.     }
  125.  
  126.     int fixateMedian(int left, int right, int axis) {
  127.         int medianIndex = getMedianIndex(left, right);
  128.         quickselect_Iterative(left, right, medianIndex, axis);
  129.         return medianIndex;
  130.     }
  131.  
  132.  
  133.     void quickselect_Iterative(int left, int right, int k, int axis) {
  134.         while (left < right) {
  135.             int pivotIndex = selectPivotIndex(left, right);
  136.             int split = partition(left, right, axis, pivotIndex);
  137.             if (k < split) {
  138.                 right = split - 1;
  139.             } else if (k > split) {
  140.                 left = split + 1;
  141.             } else {
  142.                 return;
  143.             }
  144.         }
  145.     }
  146.  
  147.     int partition(int left, int right, int axis, int pivotIndex) {
  148.         float pivotValue = getValue(pivotIndex, axis);
  149.  
  150.         // move pivot to the end
  151.         swap(pivotIndex, right);
  152.  
  153.         // swap values that are smaller than the pivot to the front
  154.         int index = left;
  155.         for (int i = left; i < right; i++) {
  156.             if (getValue(i, axis) < pivotValue) {
  157.                 swap(index, i);
  158.                 index++;
  159.             }
  160.         }
  161.  
  162.         // move pivot to correct position
  163.         swap(right, index);
  164.  
  165.         // try to move the split index closer to the median if possible
  166.         // this is important when there are many points on an axis aligned line
  167.         int medianIndex = getMedianIndex(left, right);
  168.         while (index < medianIndex && getValue(index + 1, axis) == pivotValue) {
  169.             index++;
  170.         }
  171.  
  172.         return index;
  173.     }
  174.  
  175.     inline int selectPivotIndex(int left, int right) {
  176.         return left + randomInt_Positive(left ^ right) % (right - left);
  177.     }
  178.  
  179.  
  180.  
  181.     /* Collect in Radius
  182.     *********************************************/
  183.  
  184.     struct InRadiusCollector {
  185.         Point origin;
  186.         float radius;
  187.         std::vector<Point> points;
  188.  
  189.         InRadiusCollector(Point origin, float radius) {
  190.             this->origin = origin;
  191.             this->radius = radius;
  192.         }
  193.  
  194.         void consider(Point &point) {
  195.             if (CalcDistance(origin, point) <= radius) {
  196.                 points.push_back(point);
  197.             }
  198.         }
  199.  
  200.         std::vector<Point> getPoints() {
  201.             return points;
  202.         }
  203.     };
  204.  
  205.     void collectInRadius(int left, int right, int depth, InRadiusCollector &collector) {
  206.         if (isLeaf(left, right)) {
  207.             considerPointsInBucket(left, right, collector);
  208.             return;
  209.         }
  210.  
  211.         int axis = getAxis(depth);
  212.         int medianIndex = getMedianIndex(left, right);
  213.  
  214.         Point &splitPoint = points[medianIndex];
  215.         collector.consider(splitPoint);
  216.  
  217.         float splitPos = GetKey(splitPoint, axis);
  218.         float originPos = GetKey(collector.origin, axis);
  219.  
  220.         if (originPos - collector.radius <= splitPos) {
  221.             collectInRadius(left, medianIndex - 1, depth + 1, collector);
  222.         }
  223.         if (originPos + collector.radius >= splitPos) {
  224.             collectInRadius(medianIndex + 1, right, depth + 1, collector);
  225.         }
  226.     }
  227.  
  228.  
  229.  
  230.     /* Collect k Nearest
  231.     *************************************************/
  232.  
  233.     struct PointWithDistance {
  234.         Point point;
  235.         float distance;
  236.  
  237.         PointWithDistance(Point point, float distance)
  238.             : point(point), distance(distance) {}
  239.  
  240.         friend bool operator<(const PointWithDistance &p1, const PointWithDistance &p2) {
  241.             return p1.distance < p2.distance;
  242.         }
  243.     };
  244.  
  245.     struct KNearestCollector {
  246.         unsigned int k;
  247.         Point origin;
  248.         float maxDistance;
  249.         std::priority_queue<PointWithDistance> queue;
  250.  
  251.         KNearestCollector(Point origin, int k) {
  252.             assert(k >= 0);
  253.             this->k = k;
  254.             this->origin = origin;
  255.             this->maxDistance = -1;
  256.         }
  257.  
  258.         void consider(Point &point) {
  259.             float distance = CalcDistance(origin, point);
  260.             if (queue.size() < k) {
  261.                 queue.push(PointWithDistance(point, distance));
  262.                 maxDistance = getCurrentMaxDistance();
  263.             } else if (distance < maxDistance) {
  264.                 queue.pop();
  265.                 queue.push(PointWithDistance(point, distance));
  266.                 maxDistance = getCurrentMaxDistance();
  267.             }
  268.         }
  269.  
  270.         float getCurrentMaxDistance() {
  271.             return queue.top().distance;
  272.         }
  273.  
  274.         std::vector<Point> getPoints() {
  275.             std::vector<Point> points;
  276.             while (!queue.empty()) {
  277.                 points.push_back(queue.top().point);
  278.                 queue.pop();
  279.             }
  280.             return points;
  281.         }
  282.     };
  283.  
  284.     void collectKNearest(int left, int right, int depth, KNearestCollector &collector) {
  285.         if (isLeaf(left, right)) {
  286.             considerPointsInBucket(left, right, collector);
  287.             return;
  288.         }
  289.  
  290.         int axis = getAxis(depth);
  291.         int medianIndex = getMedianIndex(left, right);
  292.  
  293.         Point &splitPoint = points[medianIndex];
  294.  
  295.         float splitPos = GetKey(splitPoint, axis);
  296.         float originPos = GetKey(collector.origin, axis);
  297.  
  298.         if (originPos <= splitPos) {
  299.             collectKNearest(left, medianIndex - 1, depth + 1, collector);
  300.             if (originPos + collector.maxDistance >= splitPos) {
  301.                 collectKNearest(medianIndex + 1, right, depth + 1, collector);
  302.             }
  303.         } else {
  304.             collectKNearest(medianIndex + 1, right, depth + 1, collector);
  305.             if (originPos - collector.maxDistance <= splitPos) {
  306.                 collectKNearest(left, medianIndex - 1, depth + 1, collector);
  307.             }
  308.         }
  309.  
  310.         collector.consider(splitPoint);
  311.     }
  312.  
  313.  
  314.     /* Utils
  315.     *************************************************/
  316.  
  317.     template<class Collector>
  318.     void considerPointsInBucket(int left, int right, Collector &collector) {
  319.         for (int i = left; i <= right; i++) {
  320.             collector.consider(points[i]);
  321.         }
  322.     }
  323.  
  324.     template<class Collector>
  325.     std::vector<Point> collectNaive(Collector collector) {
  326.         for (int i = 0; i < length; i++) {
  327.             collector.consider(points[i]);
  328.         }
  329.         return collector.getPoints();
  330.     }
  331.  
  332.     inline bool isLeaf(int left, int right) {
  333.         return left + bucketSize > right;
  334.     }
  335.  
  336.     inline int getAxis(int depth) {
  337.         return depth % ndim;
  338.     }
  339.  
  340.     inline void swap(int a, int b) {
  341.         Point tmp = points[a];
  342.         points[a] = points[b];
  343.         points[b] = tmp;
  344.     }
  345.  
  346.     inline float getValue(int index, int axis) {
  347.         return GetKey(points[index], axis);
  348.     }
  349.  
  350. };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement