Advertisement
Guest User

kdTree nearest

a guest
Sep 29th, 2016
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 6.78 KB | None | 0 0
  1. package s3;
  2.  
  3.  
  4. import java.awt.Color;
  5.  
  6. /*************************************************************************
  7.  *************************************************************************/
  8.  
  9. import edu.princeton.cs.algs4.In;
  10. import edu.princeton.cs.algs4.Out;
  11. import edu.princeton.cs.algs4.Point2D;
  12. import edu.princeton.cs.algs4.RectHV;
  13. import edu.princeton.cs.algs4.StdDraw;
  14.  
  15. public class KdTree {
  16.     private Node root;
  17.    
  18.     private static class Node {
  19.         private Point2D p; // the point
  20.         private RectHV rect; // the axis-aligned rectangle corresponding to this node
  21.         private Node lb; // the left/bottom subtree
  22.         private Node rt; // the right/top subtree
  23.        
  24.         public Node(Point2D point){
  25.             this.p = point;
  26.         }
  27.     }
  28.     // construct an empty set of points
  29.     public KdTree() {
  30.     }
  31.  
  32.     // is the set empty?
  33.     public boolean isEmpty() {
  34.         return root == null;
  35.     }
  36.  
  37.     // number of points in the set
  38.     public int size() {
  39.         return size(root);
  40.     }
  41.  
  42.     private int size(Node node) {
  43.         if (node == null)
  44.             return 0;
  45.         else
  46.             return 1 + size(node.lb) + size(node.rt);
  47.     }
  48.  
  49.     // add the point p to the set (if it is not already in the set)
  50.     public void insert(Point2D p) {
  51.         root = insertAt(root, p, true);
  52.     }
  53.  
  54.     private Node insertAt(Node node, Point2D p, boolean isX) {
  55.         if (node == null)
  56.             return new Node(p);
  57.        
  58.         double dist = Double.MAX_VALUE;
  59.        
  60.         if (isX)
  61.             dist = node.p.x() - p.x();
  62.         else
  63.             dist = node.p.y() - p.y();
  64.        
  65.         if (dist > 0)
  66.             node.lb = insertAt(node.lb, p, !isX);
  67.         else if (dist < 0)
  68.             node.rt = insertAt(node.rt, p, !isX);
  69.         else if (node.p.equals(p))
  70.             node.p = p;
  71.         else{
  72.             dist = node.p.compareTo(p);
  73.             if (dist > 0)
  74.                 insertAt(node.lb, p, !isX);
  75.             else
  76.                 insertAt(node.rt, p, !isX);
  77.         }
  78.        
  79.         return node;
  80.     }
  81.  
  82.     // does the set contain the point p?
  83.     public boolean contains(Point2D p) {
  84.         return containsAt(root, p);
  85.     }
  86.  
  87.     private boolean containsAt(Node node, Point2D p) {
  88.         if (node == null)
  89.             return false;
  90.         if (node.p.equals(p))
  91.             return true;
  92.         else if (p.compareTo(node.p) < 0)
  93.             return containsAt(node.lb, p);
  94.         else
  95.             return containsAt(node.rt, p);
  96.     }
  97.  
  98.     // draw all of the points to standard draw
  99.     public void draw() {
  100.         if (isEmpty())
  101.             return;
  102.         else
  103.             drawlb(root, 0, new Point2D(0, 0));
  104.     }
  105.  
  106.     private void drawlb(Node node, int lvl, Point2D last) {
  107.         if (node == null)
  108.             return;
  109.        
  110.         if (lvl % 2 == 0){
  111.             StdDraw.setPenColor(Color.red);
  112.             StdDraw.filledCircle(node.p.x(), node.p.y(), 0.005);
  113.             StdDraw.line(node.p.x(), last.y(), node.p.x(), 0);         
  114.         }
  115.         else{
  116.             StdDraw.setPenColor(Color.blue);
  117.             StdDraw.filledCircle(node.p.x(), node.p.y(), 0.005);           
  118.             StdDraw.line(0, node.p.y(), last.x(), node.p.y());
  119.         }
  120.        
  121.         drawlb(node.lb, lvl + 1, node.p);
  122.         drawrt(node.rt, lvl + 1, node.p);
  123.     }
  124.  
  125.     private void drawrt(Node node, int lvl, Point2D last) {
  126.         if (node == null)
  127.             return;
  128.        
  129.         if (lvl % 2 == 0){
  130.             StdDraw.setPenColor(Color.red);
  131.             StdDraw.filledCircle(node.p.x(), node.p.y(), 0.005);
  132.             StdDraw.line(node.p.x(), 1, node.p.x(), last.y());         
  133.         }
  134.         else{
  135.             StdDraw.setPenColor(Color.blue);
  136.             StdDraw.filledCircle(node.p.x(), node.p.y(), 0.005);           
  137.             StdDraw.line(last.x(), node.p.y(), 1, node.p.y());
  138.         }
  139.        
  140.         drawlb(node.lb, lvl + 1, node.p);
  141.         drawrt(node.rt, lvl + 1, node.p);
  142.     }
  143.  
  144.     // all points in the set that are inside the rectangle
  145.     public Iterable<Point2D> range(RectHV rect) {
  146.         return null;
  147.     }
  148.  
  149.     // a nearest neighbor in the set to p; null if set is empty
  150.     public Point2D nearest(Point2D p) {
  151.         if (isEmpty())
  152.             return null;
  153.         return nearest(root, p, true, root.p);
  154.     }
  155.  
  156.     private Point2D nearest(Node node, Point2D p, boolean isX, Point2D npoint) {
  157.         double dist = Double.MAX_VALUE;
  158.         if (node.p.distanceTo(p) < npoint.distanceTo(p))
  159.             npoint = node.p;
  160.        
  161.         if(isX)
  162.             dist = node.p.x() - p.x();
  163.         else
  164.             dist = node.p.y() - p.y();
  165.        
  166.        
  167.         if (node.p.equals(p))
  168.             return node.p;
  169.        
  170.         if ((node.lb != null) && (node.rt != null)){
  171.             if (dist > 0){
  172.                     return nearest(node.lb, p, !isX, npoint);
  173.             }
  174.             else{
  175.                     return nearest(node.rt, p, !isX, npoint);
  176.             }      
  177.         }
  178.         else if ((node.lb != null) && (node.rt == null)){
  179.             if (dist > 0)
  180.                 return nearest(node.lb, p, !isX, npoint);
  181.             else
  182.                 return npoint;
  183.         }
  184.         else if ((node.lb == null) && (node.rt != null)){
  185.             if (dist > 0)
  186.                 return npoint;
  187.             else
  188.                 return nearest(node.rt, p, !isX, npoint);
  189.         }
  190.         else
  191.             return npoint;
  192.     }
  193.    
  194.     public static void main(String[] args) {
  195.         In in = new In();
  196.         Out out = new Out();    
  197.         int N = in.readInt(), C = in.readInt(), T = 50;
  198.         Point2D[] queries = new Point2D[C];
  199.         KdTree tree = new KdTree();
  200.         out.printf("Inserting %d points into tree\n", N);
  201.         for (int i = 0; i < N; i++) {
  202.             tree.insert(new Point2D(in.readDouble(), in.readDouble()));
  203.         }
  204.         out.printf("tree.size(): %d\n", tree.size());
  205.         out.printf("Testing `nearest` method, querying %d points\n", C);
  206.    
  207.         for (int i = 0; i < C; i++) {
  208.             queries[i] = new Point2D(in.readDouble(), in.readDouble());
  209.             out.printf("%s: %s\n", queries[i], tree.nearest(queries[i]));
  210.         }
  211.         for (int i = 0; i < T; i++) {
  212.             for (int j = 0; j < C; j++) {
  213.                 tree.nearest(queries[j]);
  214.             }
  215.         }
  216.     }
  217.  
  218.     /*
  219.     public static void main(String[] args) {
  220.         In in = new In();
  221.         Out out = new Out();
  222.         int N = in.readInt(), C = in.readInt(), T = 20;
  223.         KdTree tree = new KdTree();
  224.         Point2D [] points = new Point2D[C];
  225.         out.printf("Inserting %d points into tree\n", N);
  226.         for (int i = 0; i < N; i++) {
  227.             tree.insert(new Point2D(in.readDouble(), in.readDouble()));
  228.         }
  229.         out.printf("tree.size(): %d\n", tree.size());
  230.         out.printf("Testing contains method, querying %d points\n", C);
  231.         for (int i = 0; i < C; i++) {
  232.             points[i] = new Point2D(in.readDouble(), in.readDouble());
  233.             out.printf("%s: %s\n", points[i], tree.contains(points[i]));
  234.         }
  235.         for (int i = 0; i < T; i++) {
  236.             for (int j = 0; j < C; j++) {
  237.                 tree.contains(points[j]);
  238.             }
  239.         }
  240.     }*/
  241. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement