Princeton Algorithm Kd-Trees

本次課程作業是編寫一個 2D 樹的數據結構,以表示單位正方形中的一組點,並支持高效的範圍搜索(查找查詢矩形中包含的所有點),以及高效的最近鄰居搜索(找到最接近查詢點的點)。

2D 樹有許多應用,從對天文物體進行分類到計算機動畫,再到加速神經網絡,再到挖掘數據再到圖像檢索等。

首先要用暴力做法做一次,題目限定只能使用 SET 或者 java.util.TreeSet,這個就比較簡單了,只需要注意一下 corner case,然後注意參數不合法的時候拋異常。

2D 樹的搜索和插入的算法與 BST 的算法相似,但是在根結點處,我們使用 x 座標來判斷大小,如果要插入的點的 x 座標比在根結點的點小,向左移動,否則向右移動;然後在下一個級別,我們使用 y 座標來判斷大小,如果要插入的點的 y 座標比結點中的點小,則向左移動,否則向右移動;然後在下一級,繼續使用 x 座標,依此類推……

由此,我們可以得到下圖:

2D 樹插入示意

2D 樹相對於 BST 的主要優勢在於,它支持範圍搜索和最近鄰居搜索的高效實現。每個節點對應於單位正方形中與軸對齊的矩形,該矩形將其子樹中的所有點都包含在內。根結點對應整個單位正方形,根的左、右子元素對應於兩個矩形,該兩個矩形被根結點的 x 座標分開,以此類推……

由此,我們可以得到範圍搜索和最近鄰居搜索的思想思路。

進行範圍搜索時,從根結點開始,遞歸地搜索左右子樹,若查詢矩形不與該結點對應的矩形相交,那麼就不需要探索該節點及其子樹。子樹只有在可能包含查詢矩形中包含的點時才被搜索。

進行最近鄰居搜索時,從根結點開始,遞歸地搜索左右子樹,如果到目前爲止發現的最近點比查詢點與結點對應的矩形之間的距離更近,則不需要探索該結點及其子樹。也就是說,僅當一個結點可能包含一個比目前發現的最佳結點更接近的點時,才進行搜索。

這樣的剪枝規則,依賴於能否快速找到附近的點。因此,我們需要注意在遞歸代碼中,當有兩個可能的子樹的時候,總是選擇位於分隔線同一側的子樹作爲要探索的第一棵子樹的查詢點。這是因爲在探索第一棵子樹時發現的有可能是最近的點,將有利於探索第二棵子樹時剪枝。

這裏在實現的時候,遞歸先左先右當然都可以得到正確的結果,但是這裏必須調整遞歸的順序,才能達到剪枝的效果。

這是因爲,如果左孩子包含 p,由於矩形是越來越小的,所以若點在某個 node 的矩形內被包含,則該 node 的 p 離這個所求 p 的距離就可能越小。min 越小,那麼剪枝的效果就越明顯,因爲越來越多的就不需要再計算了。於是,應該始終優先去遞歸那個 contains§ 的方向(因爲有且只有可能要麼是 left 要麼還是 right)包含 p。

如果不進行剪枝,那麼就算你的代碼功底非常好,在規定時間內求得了正確解,沒有超時,也一樣不能通過測評:

 - student sequence of kd-tree nodes involved in calls to Point2D methods:
 A D F I G B C E J
 - reference sequence of kd-tree nodes involved in calls to Point2D methods:
 A D F I G B C
 - failed on trial 1 of 1000

具體剪枝的策略就是,如果左孩子包含了目標點,那麼就去左孩子,如果右孩子包含了目標點,那麼就去右孩子。有可能左右孩子都不包含目標點,那麼離誰近就去誰那。

// 先左先右當然都可以得到正確的結果,但是
// 這裏必須調整遞歸的順序,才能達到剪枝的效果
if (node.left != null && node.left.rect.contains(p)) {
	// 如果左孩子包含 p,由於矩形是越來越小的,所以若點在某個 node 的矩形內被包含,則該 node 的 p 離這個所求 p 的距離就可能越小
	// min 越小,那麼剪枝的效果就越明顯,因爲越來越多的就不需要再計算了
	// 於是,應該始終優先去遞歸那個 contains(p) 的方向(因爲有且只有可能要麼是 left 要麼還是 right)包含 p
	findNearest(p, node.left);
	findNearest(p, node.right);
} else if (node.right != null && node.right.rect.contains(p)) {
	// 如果右孩子包含就先去右邊
	findNearest(p, node.right);
	findNearest(p, node.left);
} else {
	// 也可能出現兩個都不包含的情況,那麼離誰近就先去誰那
	// 注意調用時 null 的問題要特別處理,可以設置爲無窮大
	double toLeft = node.left != null ? node.left.rect.distanceSquaredTo(p) : Double.POSITIVE_INFINITY;
	double toRight = node.right != null ? node.right.rect.distanceSquaredTo(p) : Double.POSITIVE_INFINITY;
	if (toLeft < toRight) {
		findNearest(p, node.left);
		findNearest(p, node.right);
	} else {
		findNearest(p, node.right);
		findNearest(p, node.left);
	}
}

爲了代碼實現的方便,二叉樹當然要用遞歸的寫法啦。

課程提供了若干可視化工具用於調試。

draw() 函數的正確性將會大幅度提高 debug 的效率,所以這個函數一定要寫的正確。

在可視化過程中,使用暴力法求解的答案會標註爲紅色,使用 KDTree 方法求解的會標註爲藍色。由於我們非常有信心,暴力法肯定是對的,所以可以用這個方法來檢驗 KdTree 的搜索是不是正確。

KdTree 的可視化
範圍搜索的可視化示意
如圖就表示搜索結果是錯誤的,因爲兩個結果不一樣

使用上也非常簡單:當檢驗區域搜索的時候,只需要用鼠標在上面畫一個矩形;當檢驗最近鄰居的時候,只需要將鼠標移動到想要搜索的那個點對應的位置上(也許這個點並沒有在圖中畫出)。

另一個難點是處理重疊的點。重疊點在統計個數的時候不能被重複計算,我簡單地開了一個 same 數組,但是可能沒有必要。

另外特別要注意每一個新增點的時候,它對應的 RectHV 的範圍一定要搞清楚,否則後面的事情沒法做。不過這個也簡單,只要把 draw() 寫了,然後點幾個點,根據畫出來的圖馬上就知道自己寫的對不對了。如果圖和自己預想的不一樣,那就肯定是寫錯了,這個是最容易 debug 的。

以下是完整代碼,該代碼通過 100% 的測試數據,得分 100 分。

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;

import java.util.ArrayList;
import java.util.TreeSet;

public class PointSET {

    private final TreeSet<Point2D> set;

    public PointSET() {
        set = new TreeSet<>();
    }

    public boolean isEmpty() {
        return set.isEmpty();
    }

    public int size() {
        return set.size();
    }

    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        if (!contains((p))) {
            set.add(p);
        }
    }

    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        return set.contains(p);
    }

    public void draw() {
        for (Point2D p : set) {
            p.draw();
        }
    }

    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }
        ArrayList<Point2D> list = new ArrayList<>();
        for (Point2D p : set) {
            if (rect.contains(p)) {
                list.add(p);
            }
        }
        return list;
    }

    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        Point2D ans = null;
        if (!isEmpty()) {
            double min = Double.POSITIVE_INFINITY;
            for (Point2D pp : set) {
                // Do not call 'distanceTo()' in this program; instead use 'distanceSquaredTo()'. [Performance]
                double d = pp.distanceSquaredTo(p);
                if (d < min) {
                    min = d;
                    ans = pp;
                }
            }
        }
        return ans;
    }

    public static void main(String[] args) {
        PointSET ps = new PointSET();
        Point2D p1 = new Point2D(1, 1);
        Point2D p2 = new Point2D(1, 2);
        Point2D p3 = new Point2D(2, 1);
        Point2D p4 = new Point2D(0, 0);
        ps.insert(p1);
        ps.insert(p2);
        ps.insert(p3);
        ps.insert(p4);
        System.out.println(ps.nearest(p4));
        for (Point2D p : ps.range(new RectHV(1, 1, 3, 3))) {
            System.out.println(p);
        }
    }

}
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;

import java.util.ArrayList;

/**
 * @author jxtxzzw
 */
public class KdTree {
    private Node root;
    private int size;

    private static class Node {

        private final Point2D p;
        private final int level;
        private Node left;
        private Node right;
        private final RectHV rect;
        // 記錄重疊的點
        private final ArrayList<Point2D> same = new ArrayList<>();


        // 對根結點
        public Node(Point2D p) {
            // 根結點層數是 0,範圍是單位正方形
            this(p, 0, 0, 1, 0, 1);
        }

        public Node(Point2D p, int level, double xmin, double xmax, double ymin, double ymax) {
            this.p = p;
            this.level = level;
            rect = new RectHV(xmin, ymin, xmax, ymax);
        }

        public void addSame(Point2D point) {
            same.add(point);
        }

        public boolean hasSamePoint() {
            return !same.isEmpty();
        }
    }

    private Point2D currentNearest;
    private double min;

    public KdTree() {

    }

    public boolean isEmpty() {
        return size == 0;
    }

    public int size() {
        return size;
    }

    private int compare(Point2D p, Node n) {
        if (n.level % 2 == 0) {
            // 如果是偶數層,按 x 比較
            if (Double.compare(p.x(), n.p.x()) == 0) {
                return Double.compare(p.y(), n.p.y());
            } else {
                return Double.compare(p.x(), n.p.x());
            }
        } else {
            // 按 y 比較
            if (Double.compare(p.y(), n.p.y()) == 0) {
                return Double.compare(p.x(), n.p.x());
            } else {
                return Double.compare(p.y(), n.p.y());
            }
        }
    }

    private Node generateNode(Point2D p, Node parent) {
        int cmp = compare(p, parent);
        if (cmp < 0) {
            if (parent.level % 2 == 0) {
                // 偶數層,比較結果是小於,說明是加在左邊
                // 那麼它的 xmin, ymin, ymax 都和父結點一樣,xmax 設置爲父結點的 p.x()
                return new Node(p, parent.level + 1, parent.rect.xmin(), parent.p.x(), parent.rect.ymin(), parent.rect.ymax());
            } else {
                // 奇數層,加在下邊,那麼只需要修改 ymax
                return new Node(p, parent.level + 1, parent.rect.xmin(), parent.rect.xmax(), parent.rect.ymin(), parent.p.y());
            }
        } else {
            if (parent.level % 2 == 0) {
                // 偶數層,加在右邊,那麼只需要修改 xmin
                return new Node(p, parent.level + 1, parent.p.x(), parent.rect.xmax(), parent.rect.ymin(), parent.rect.ymax());

            } else {
                // 奇數層,比較結果是大於,說明是加在上邊,修改 ymin
                return new Node(p, parent.level + 1, parent.rect.xmin(), parent.rect.xmax(), parent.p.y(), parent.rect.ymax());

            }
        }
    }

    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        } else {
            if (root == null) {
                // 初始化根結點
                size++;
                root = new Node(p);
            } else {
                // 二叉樹,用遞歸的寫法去調用
                insert(p, root);
            }
        }
    }

    private void insert(Point2D p, Node node) {
        int cmp = compare(p, node);
        // 如果比較結果是小於,那麼就是要往左邊走,右邊同理
        if (cmp < 0) {
            // 走到頭了就新建,否則繼續走
            if (node.left == null) {
                size++;
                node.left = generateNode(p, node);
            } else {
                insert(p, node.left);
            }
        } else if (cmp > 0) {
            if (node.right == null) {
                size++;
                node.right = generateNode(p, node);
            } else {
                insert(p, node.right);
            }
        }
        // 重疊的點,size 不加 1
    }


    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        } else {
            if (root == null) {
                return false;
            } else {
                // 遞歸的寫法
                return contains(p, root);
            }
        }
    }

    private boolean contains(Point2D p, Node node) {
        if (node == null) {
            return false;
        } else if (p.equals(node.p)) {
            return true;
        } else {
            if (compare(p, node) < 0) {
                return contains(p, node.left);
            } else {
                return contains(p, node.right);
            }
        }
    }

    public void draw() {
        // 清空畫布
        StdDraw.clear();
        // 遞歸調用
        draw(root);
    }

    private void draw(Node node) {
        if (node != null) {
            // 點用黑色
            StdDraw.setPenColor(StdDraw.BLACK);
            // 畫點
            node.p.draw();
            // 根據是不是偶數設置紅色還是藍色
            if (node.level % 2 == 0) {
                StdDraw.setPenColor(StdDraw.RED);
                StdDraw.line(node.p.x(), node.rect.ymin(), node.p.x(), node.rect.ymax());
            } else {
                StdDraw.setPenColor(StdDraw.BLUE);
                StdDraw.line(node.rect.xmin(), node.p.y(), node.rect.xmax(), node.p.y());
            }
            // 遞歸畫
            draw(node.left);
            draw(node.right);
        }
    }

    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }
        if (isEmpty()) {
            return null;
        }
        // 遞歸調用
        return new ArrayList<>(range(rect, root));
    }

    private ArrayList<Point2D> range(RectHV rect, Node node) {
        ArrayList<Point2D> list = new ArrayList<>();
        // A subtree is searched only if it might contain a point contained in the query rectangle.
        if (node != null && rect.intersects(node.rect)) {
            // 遞歸地檢查左右孩子
            list.addAll(range(rect, node.left));
            list.addAll(range(rect, node.right));
            // 如果對當前點包含,則加入
            if (rect.contains(node.p)) {
                list.add(node.p);
                // 重疊點應該只被計算一次
            }
        }
        return list;
    }

    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        if (isEmpty()) {
            return null;
        }
        currentNearest = null;
        min = Double.POSITIVE_INFINITY;
        findNearest(p, root);
        return currentNearest;
    }

    private void findNearest(Point2D p, Node node) {
        if (node == null) {
            return;
        }
        // The square of the Euclidean distance between the point {@code p} and the closest point on this rectangle; 0 if the point is contained in this rectangle
        if (node.rect.distanceSquaredTo(p) <= min) {
            // Do not call 'distanceTo()' in this program; instead use 'distanceSquaredTo()'. [Performance]
            double d = node.p.distanceSquaredTo(p);
            if (d < min) {
                min = d;
                currentNearest = node.p;
            }
            // 先左先右當然都可以得到正確的結果,但是
            // 這裏必須調整遞歸的順序,才能達到剪枝的效果
            if (node.left != null && node.left.rect.contains(p)) {
                // 如果左孩子包含 p,由於矩形是越來越小的,所以若點在某個 node 的矩形內被包含,則該 node 的 p 離這個所求 p 的距離就可能越小
                // min 越小,那麼剪枝的效果就越明顯,因爲越來越多的就不需要再計算了
                // 於是,應該始終優先去遞歸那個 contains(p) 的方向(因爲有且只有可能要麼是 left 要麼還是 right)包含 p
                findNearest(p, node.left);
                findNearest(p, node.right);
            } else if (node.right != null && node.right.rect.contains(p)) {
                // 如果右孩子包含就先去右邊
                findNearest(p, node.right);
                findNearest(p, node.left);
            } else {
                // 也可能出現兩個都不包含的情況,那麼離誰近就先去誰那
                // 注意調用時 null 的問題要特別處理,可以設置爲無窮大
                double toLeft = node.left != null ? node.left.rect.distanceSquaredTo(p) : Double.POSITIVE_INFINITY;
                double toRight = node.right != null ? node.right.rect.distanceSquaredTo(p) : Double.POSITIVE_INFINITY;
                if (toLeft < toRight) {
                    findNearest(p, node.left);
                    findNearest(p, node.right);
                } else {
                    findNearest(p, node.right);
                    findNearest(p, node.left);
                }
            }
        }

    }

    public static void main(String[] args) {
        KdTree kd;
        kd = new KdTree();
        kd.insert(new Point2D(0.7, 0.2));
        kd.insert(new Point2D(0.5, 0.4));
        kd.insert(new Point2D(0.2, 0.3));
        kd.insert(new Point2D(0.4, 0.7));
        kd.insert(new Point2D(0.9, 0.6));
        assert kd.nearest(new Point2D(0.73, 0.36)).equals(new Point2D(0.7, 0.2));
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章