Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package s3;
- /*************************************************************************
- *************************************************************************/
- import edu.princeton.cs.algs4.In;
- import edu.princeton.cs.algs4.Out;
- import edu.princeton.cs.algs4.Point2D;
- import edu.princeton.cs.algs4.RectHV;
- public class KdTree {
- private Node root;
- private static class Node {
- private Point2D p; // the point
- private RectHV rect; // the axis-aligned rectangle corresponding to this node
- private Node lb; // the left/bottom subtree
- private Node rt; // the right/top subtree
- public Node(Point2D point, RectHV rect){
- this.p = point;
- this.rect = rect;
- }
- }
- // construct an empty set of points
- public KdTree() {
- }
- // is the set empty?
- public boolean isEmpty() {
- return root == null;
- }
- // number of points in the set
- public int size() {
- return size(root);
- }
- private int size(Node node) {
- if (node == null)
- return 0;
- else
- return 1 + size(node.lb) + size(node.rt);
- }
- // add the point p to the set (if it is not already in the set)
- public void insert(Point2D p) {
- root = insertAt(root, p, true, new RectHV(0.0, 0.0, 1.0, 1.0));
- }
- private Node insertAt(Node node, Point2D p, boolean isX, RectHV rect) {
- if (node == null){
- return new Node(p, rect);
- }
- double dist = Double.MAX_VALUE;
- if(isX)
- dist = node.p.x() - p.x();
- else
- dist = node.p.y() - p.y();
- if (dist > 0)
- node.lb = insertAt(node.lb, p, !isX, shRect(node.p, rect, isX, true));
- else if (dist < 0)
- node.rt = insertAt(node.rt, p, !isX, shRect(node.p, rect, isX, false));
- else if(dist == 0 && !node.p.equals(p)){
- if(isX)
- dist = node.p.y() - p.y();
- else
- dist = node.p.x() - p.x();
- if (dist > 0)
- node.lb = insertAt(node.lb, p, !isX, shRect(p, rect, isX, true));
- else if (dist < 0)
- node.rt = insertAt(node.rt, p, !isX, shRect(p, rect, isX, false));
- }
- else
- node.p = p;
- return node;
- }
- private RectHV shRect(Point2D p, RectHV rect, boolean isX, boolean isLB) {
- if (isLB){
- if(isX){
- return new RectHV(rect.xmin(), rect.ymin(), p.x(), rect.ymax());
- }else{
- return new RectHV(rect.xmin(), rect.ymin(), rect.xmax(), p.y());
- }
- } else {
- if(isX){
- return new RectHV(p.x(), rect.ymin(), rect.xmax(), rect.ymax());
- }else{
- return new RectHV(rect.xmin(), p.y(), rect.xmax(), rect.ymax());
- }
- }
- }
- // does the set contain the point p?
- public boolean contains(Point2D p) {
- return containsAt(root, p, true);
- }
- private boolean containsAt(Node node, Point2D p, boolean isX) {
- if (node == null)
- return false;
- double dist;
- if (isX)
- dist = node.p.x() - p.x();
- else
- dist = node.p.y() - p.y();
- if (node.p.equals(p))
- return true;
- else if (dist > 0)
- return containsAt(node.lb, p, !isX);
- else if (dist < 0)
- return containsAt(node.rt, p, !isX);
- else
- {
- if(isX)
- dist = node.p.y() - p.y();
- else
- dist = node.p.x() - p.x();
- if (dist > 0)
- return containsAt(node.lb, p, !isX);
- else if (dist < 0)
- return containsAt(node.rt, p, !isX);
- }
- return false;
- }
- // draw all of the points to standard draw
- public void draw() {
- }
- // all points in the set that are inside the rectangle
- public Iterable<Point2D> range(RectHV rect) {
- return null;
- }
- // a nearest neighbour in the set to p; null if set is empty
- public Point2D nearest(Point2D p) {
- if (isEmpty())
- return null;
- return nearest(root, p, root.p);
- }
- private Point2D nearest(Node node, Point2D p, Point2D npoint) {
- if (node == null)
- return npoint;
- if (node.rect.distanceSquaredTo(p) <= npoint.distanceTo(p)){
- if (node.p.distanceTo(p) < npoint.distanceTo(p))
- npoint = node.p;
- Point2D lb = nearest(node.lb, p, npoint);
- Point2D rt = nearest(node.lb, p, npoint);
- if ((lb.distanceTo(p) < rt.distanceTo(p)) && (lb.distanceTo(p) < npoint.distanceTo(p)))
- return lb;
- else if ((lb.distanceTo(p) > rt.distanceTo(p)) && (rt.distanceTo(p) < npoint.distanceTo(p)))
- return rt;
- else
- return npoint;
- }
- else
- return npoint;
- /*
- double dist = Double.MAX_VALUE;
- if (node.p.distanceTo(p) < npoint.distanceTo(p))
- npoint = node.p;
- if(isX)
- dist = node.p.x() - p.x();
- else
- dist = node.p.y() - p.y();
- if (node.p.equals(p))
- return node.p;
- if ((node.lb != null) && (node.rt != null)){
- if (dist > 0){
- return nearest(node.lb, p, !isX, npoint);
- }
- else if(dist < 0)
- {
- return nearest(node.rt, p, !isX, npoint);
- }
- else
- {
- if(isX)
- dist = node.p.y() - p.y();
- else
- dist = node.p.x() - p.x();
- if (dist > 0)
- return nearest(node.lb, p, !isX, npoint);
- else
- return nearest(node.rt, p, !isX, npoint);
- }
- }
- else if ((node.lb != null) && (node.rt == null)){
- if (dist > 0)
- return nearest(node.lb, p, !isX, npoint);
- else
- return npoint;
- }
- else if ((node.lb == null) && (node.rt != null)){
- if (dist > 0)
- return npoint;
- else
- return nearest(node.rt, p, !isX, npoint);
- }
- else
- return npoint;
- */
- }
- public static void main(String[] args) {
- In in = new In();
- Out out = new Out();
- int N = in.readInt(), C = in.readInt(), T = 50;
- Point2D[] queries = new Point2D[C];
- KdTree tree = new KdTree();
- out.printf("Inserting %d points into tree\n", N);
- for (int i = 0; i < N; i++) {
- tree.insert(new Point2D(in.readDouble(), in.readDouble()));
- }
- out.printf("tree.size(): %d\n", tree.size());
- out.printf("Testing `nearest` method, querying %d points\n", C);
- for (int i = 0; i < C; i++) {
- queries[i] = new Point2D(in.readDouble(), in.readDouble());
- out.printf("%s: %s\n", queries[i], tree.nearest(queries[i]));
- }
- for (int i = 0; i < T; i++) {
- for (int j = 0; j < C; j++) {
- tree.nearest(queries[j]);
- }
- }
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement