An attempt at traversing a KD tree: no recursion? And no queue either?

《A short skit》

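Onlooker: A KD tree? Traversed without recursion? And without a queue? Quit kidding, kid.
Me: Believe it or not. I'm just that good, and just that confident.
Onlooker: Then do you dare put it out there for us to see?
Me: Just you watch...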

Main text

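In this post I'll introduce a way to traverse a KD tree (a binary tree) with a Python iterator. For an introduction to iteration, see this article I wrote.
URL: http://blog.csdn.net/weixin_37722024/article/details/62424311
I won't go into how to write an iterator here. I'll only explain the member variables the traversal needs and the traversal logic inside the member function __next__().
Python documentation on iterators: https://docs.python.org/3/tutorial/classes.html#iterators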

I. Building the KD tree class

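Building a KD tree, which is just a binary tree, is fairly simple: construct the root node first, then recursively construct the root's left and right subtrees. That is outside the scope of this post.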

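The KD_Node class defines what every binary tree needs (the node value, the left subtree and the right subtree) plus the split that each KD-tree node requires. Beyond these basic members there are three additional ones: cur_trav, a class attribute shared by all instances, which holds a node of type KD_Node and serves as the cursor during the search; flag_trav, an instance attribute of type int, which records each node's state during the traversal; and father, an instance attribute of type KD_Node, which holds each node's parent.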

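With this structure, each node gains one int (plus the reference to its parent), and the class gains one cursor node shared by all nodes (that is, by all instances of the class). The extra memory cost is limited, yet it brings a lot of flexibility and keeps the code concise. The same traversal can do more than fetch each node's value: it can also be used to draw the KD tree's space partition with matplotlib, or even to simulate a KNN search. That is what I wanted to do in the draw_KDT function, but it isn't finished yet.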

II. The iteration process

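As mentioned above, the search logic all lives in __next__(); for the exact flow, please read the code. When I started writing the iterator I wanted to write __next__() recursively. I tried it, it didn't work, I gave up, started over from scratch, and cobbled together this thing.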

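The variables used for traversal all keep their initial values right after the KD tree is built. When traversal begins, the root node is assigned to the cursor cur_trav (a reference, not a copy); this is the node currently being searched. During the search, the state flag_trav of the node under the cursor decides whether to return the current node, move to its left child, or move to its right child. So the state uses only the lowest 3 bits: bit0 says whether the current node has been visited, bit1 whether its left child has been visited, and bit2 whether its right child has been visited. father exists so that the cursor can go back to the parent and from there reach the other subtree.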
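
To make the three bits easier to read, they can be given names. The following is only a small sketch: the constants SELF_DONE, LEFT_DONE, RIGHT_DONE and ALL_DONE and the helper describe() are hypothetical names for illustration, whereas the code in this post uses the literals 0X01, 0X02, 0X04 and 0X07 directly.

```python
# Hypothetical names for the three low bits of flag_trav.
SELF_DONE  = 0x01   # bit0: the node itself has been returned
LEFT_DONE  = 0x02   # bit1: the left child has been entered or skipped
RIGHT_DONE = 0x04   # bit2: the right child has been entered or skipped
ALL_DONE   = SELF_DONE | LEFT_DONE | RIGHT_DONE   # 0x07


def describe(flag):
    """Report which of the three steps are already done for a flag value."""
    return {
        "self":  bool(flag & SELF_DONE),
        "left":  bool(flag & LEFT_DONE),
        "right": bool(flag & RIGHT_DONE),
    }


flag = 0              # a freshly built node
flag |= SELF_DONE     # the node has been returned to the caller
flag |= LEFT_DONE     # the left subtree has been entered (or skipped)

print(describe(flag))
print((flag & ALL_DONE) == ALL_DONE)   # is this node completely finished?
```

Note that in Python `&` binds more tightly than `==`, so the unparenthesized form `flag & 0x07 == 0x07` used in __next__() means the same as the parenthesized one here.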

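How the node state flag_trav changes:
1. The initial value is 0, meaning the current node, its left child and its right child have not been searched yet.
2. This is a preorder traversal: the current node first, then the left side, and the right side last. So [bit2,bit1,bit0] takes only four states: [000], [001], [011], [111].
3. A child that does not exist is skipped. If neither child exists, the cursor jumps to the parent of the current node.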
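
These rules can be tried on a small hand-built tree. The sketch below is a stripped-down illustration and not the code of this post: Node and step() are hypothetical names, and the tree shape (A with children B and C, B with only a right child D) is made up so that the "skip a missing child" rule gets exercised.

```python
class Node:
    """Hypothetical stand-in for KD_Node with only the traversal fields."""

    def __init__(self, name, left=None, right=None):
        self.name = name
        self.left = left
        self.right = right
        self.father = None
        self.flag_trav = 0
        for child in (left, right):
            if child is not None:
                child.father = self


def step(cursor):
    """Move from cursor to the next unvisited node; return None when done."""
    while True:
        if cursor.flag_trav & 0x07 == 0x07:    # [111]: node and both sides done
            if cursor.father is None:
                return None                    # back at the root: finished
            cursor = cursor.father
        elif cursor.flag_trav & 0x01 == 0:     # [000] -> [001]: visit the node
            cursor.flag_trav |= 0x01
            return cursor
        elif cursor.flag_trav & 0x02 == 0:     # [001] -> [011]: go left if any
            cursor.flag_trav |= 0x02
            if cursor.left is not None:
                cursor = cursor.left
        else:                                  # [011] -> [111]: go right if any
            cursor.flag_trav |= 0x04
            if cursor.right is not None:
                cursor = cursor.right


#        A
#       / \
#      B   C
#       \
#        D
d = Node("D")
b = Node("B", right=d)
c = Node("C")
a = Node("A", left=b, right=c)

order = []
cursor = step(a)
while cursor is not None:
    order.append(cursor.name)
    cursor = step(cursor)

print(order)                                   # ['A', 'B', 'D', 'C']
print([bin(n.flag_trav) for n in (a, b, c, d)])
```

The visit order is parent, then left, then right, and every flag ends at 0b111 once the traversal is over.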

III. Code

Here are the key pieces straight away. If you can't be bothered to read the long listing, just skim these few snippets.

Snippet 1, the class declaration:

class KD_Node:
    cur_trav = None             # cursor for traversal.

    def __init__(self, point=None, split=None, L=None, R=None, father=None):
        """
        initiate a kd tree.
        point: datum of this node
        split: split plane for this node
        L:     left son
        R:     right son
        father: father of this node, if root it's None
        """
        self.point  = point
        self.split  = split
        self.left   = L
        self.right  = R
        self.father = father
        self.flag_trav = 0      # traversal flag.
                                #   bit 0: this node has been visited
                                #   bit 1: its left son has been handled
                                #   bit 2: its right son has been handled
        ......

Snippet 2, the __next__() function:

    def __next__(self):
        # Traverse the tree without recursion.
        if KD_Node.cur_trav is None:        # First use of cur_trav: start at self.
            KD_Node.cur_trav = self

        cursor = KD_Node.cur_trav
        while True:
            if cursor.flag_trav & 0x07 == 0x07:     # flag == 7 (all three bits
                                                    # set): this node and both
                                                    # of its subtrees are done.
                if cursor.father is None:           # back at the root: finished
                    raise StopIteration
                else:
                    cursor = cursor.father          # climb to the parent

            elif cursor.flag_trav & 0x01 == 0:      # if bit0 == 0,
                cursor.flag_trav |= 0x01            # set bit0 = 1
                break                               # BREAK! return current.

            elif cursor.flag_trav & 0x02 == 0:      # if bit1 == 0,
                cursor.flag_trav |= 0x02            # set bit1 = 1
                if cursor.left is not None:
                    cursor = cursor.left            # set cursor => left son
                else:                               # no left son, skip
                    continue

            elif cursor.flag_trav & 0x04 == 0:      # if bit2 == 0,
                cursor.flag_trav |= 0x04            # set bit2 = 1
                if cursor.right is not None:
                    cursor = cursor.right           # set cursor => right son
                else:                               # no right son, skip
                    continue
        KD_Node.cur_trav = cursor

        return KD_Node.cur_trav

Snippet 3, a concise traversal:

def main():
    kd = None
    kd = CreateKDT(kd, X)

    for node in kd:
        print( '*' + ' '*17, node.point, node.split, ' '*22 + '*' )
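
One thing to keep in mind with this loop: the flags and the shared cursor keep their final values once a traversal finishes, so a second for loop over the same tree yields nothing until clear_trav() resets them. That is why main() in the complete code calls kd.clear_trav() before draw_KDT(kd). Here is a minimal self-contained sketch of that behaviour. MiniNode is a hypothetical, cut-down copy of KD_Node (no split, hand-built tree), and its __next__() is a condensed version of snippet 2.

```python
class MiniNode:
    cur_trav = None                       # shared cursor, as in KD_Node

    def __init__(self, point, left=None, right=None):
        self.point = point
        self.left = left
        self.right = right
        self.father = None
        self.flag_trav = 0
        for child in (left, right):
            if child is not None:
                child.father = self

    def clear_trav(self):
        """Reset the shared cursor and every flag in this subtree."""
        MiniNode.cur_trav = None
        self.flag_trav = 0
        for child in (self.left, self.right):
            if child is not None:
                child.clear_trav()

    def __iter__(self):
        return self

    def __next__(self):
        cursor = self if MiniNode.cur_trav is None else MiniNode.cur_trav
        while True:
            if cursor.flag_trav & 0x07 == 0x07:
                if cursor.father is None:
                    raise StopIteration
                cursor = cursor.father
            elif cursor.flag_trav & 0x01 == 0:
                cursor.flag_trav |= 0x01
                break
            elif cursor.flag_trav & 0x02 == 0:
                cursor.flag_trav |= 0x02
                if cursor.left is not None:
                    cursor = cursor.left
            else:
                cursor.flag_trav |= 0x04
                if cursor.right is not None:
                    cursor = cursor.right
        MiniNode.cur_trav = cursor
        return cursor


root = MiniNode((3, 5), MiniNode((2, 4)), MiniNode((5, 2)))

first = [n.point for n in root]     # a full traversal
second = [n.point for n in root]    # flags are still 0b111: nothing comes out
root.clear_trav()                   # reset the flags and the cursor
third = [n.point for n in root]     # works again

print(first)
print(second)
print(third)
```

Because the cursor is a class attribute, this design also supports only one traversal at a time across all trees built from the same class.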

The complete code follows:

import numpy as np
import matplotlib.pyplot as plt
"""
X,  feature vectors
Y,  class of X
X_with_class, I just use this to draw a graphic in the piece of code at \
    the bottom.
D,  dimension of each of vectors.
"""
# Construct initial to be classified data
X = np.array([ (3,5), (2,4), (1,1), (5,2), (1,5), (4,1) ])
Y = [ 'g', 'g', 'r', 'r', 'g', 'r' ]
X_with_class = [ [X[a,0],X[a,1],Y[a]] for a in range(len(X)) ]

D = 0
if len(X[0]):
    D = len(X[0])


class KD_Node:
    cur_trav = None             # cursor for traversal.

    def __init__(self, point=None, split=None, L=None, R=None, father=None):
        """
        initiate a kd tree.
        point: datum of this node
        split: split plane for this node
        L:     left son
        R:     right son
        father: father of this node, if root it's None
        """
        self.point  = point
        self.split  = split
        self.left   = L
        self.right  = R
        self.father = father
        self.flag_trav = 0      # traversal flag.
                                #   bit 0: this node has been visited
                                #   bit 1: its left son has been handled
                                #   bit 2: its right son has been handled

    def clear_trav(self):
        KD_Node.cur_trav = None
        self.flag_trav = 0
        if self.left:
            self.left.clear_trav()
        if self.right:
            self.right.clear_trav()

    def __iter__(self):
        return self

    def __next__(self):
        # Traverse the tree without recursion.
        if KD_Node.cur_trav is None:        # First use of cur_trav: start at self.
            KD_Node.cur_trav = self

        cursor = KD_Node.cur_trav
        while True:
            if cursor.flag_trav & 0x07 == 0x07:     # flag == 7 (all three bits
                                                    # set): this node and both
                                                    # of its subtrees are done.
                if cursor.father is None:           # back at the root: finished
                    raise StopIteration
                else:
                    cursor = cursor.father          # climb to the parent

            elif cursor.flag_trav & 0x01 == 0:      # if bit0 == 0,
                cursor.flag_trav |= 0x01            # set bit0 = 1
                break                               # BREAK! return current.

            elif cursor.flag_trav & 0x02 == 0:      # if bit1 == 0,
                cursor.flag_trav |= 0x02            # set bit1 = 1
                if cursor.left is not None:
                    cursor = cursor.left            # set cursor => left son
                else:                               # no left son, skip
                    continue

            elif cursor.flag_trav & 0x04 == 0:      # if bit2 == 0,
                cursor.flag_trav |= 0x04            # set bit2 = 1
                if cursor.right is not None:
                    cursor = cursor.right           # set cursor => right son
                else:                               # no right son, skip
                    continue
        KD_Node.cur_trav = cursor

#        print("cursor:", KD_Node.cur_trav, "self flag:", KD_Node.cur_trav.flag_trav)

        return KD_Node.cur_trav


def CreateKDT(node=None, data=None, father=None):
    """
    TODO: DOC FOR CreateKDT
    INPUT: node, 
           data, [ (3,5), (2,4), (1,1) ]
           father, the father
    OUTPUT: 
    """
    # variance for each dimension, the biggest is desirable.
    dim = D
    var = np.var(data, axis=0)
    split = np.argmax(var)
    # Calc out the position of current node. Here the middle.
    pos = int(len(data)/2)
    # Using current split plane to split current data
    pos_list = np.argpartition(data[:,split], pos)
    point = data[pos_list[pos]]
    """
    print procedure
    print("#"*20)
    print("data:",list(data))
    print("split:",split,",",data[:,split])
    print("pos list: ",pos_list)
    print("pos:",pos)
    print("the node data:",data[pos_list[pos]] )
    """
    """
    print procedure
    print("DEBUG: LEFT=", data[pos_list[:pos]] )
    print("DEBUG: RIGHT=", data[pos_list[(pos+1):]] )
    """
    node = KD_Node(point, split, father=father)
    if len(data) > 1:
        if len(data[pos_list[:pos]]) != 0:
            node.left = CreateKDT(node.left, data[pos_list[:pos]], node)
        if len(data[pos_list[(pos+1):]]) != 0:
            node.right = CreateKDT(node.right, data[pos_list[(pos+1):]], node)
    return node

def get_split_pos(data, split):
    """Return the position (the middle index) at which to split data."""
    # Unused helper. split is accepted but not needed to find the midpoint.
    return len(data) // 2

def preorder(node, depth=-1):
    """
    Preorder a KD node
    """
    if node:
        s = '#' + '-'*50 + '#\n'
        s += 'Node:' + str(node) + '\n'
        s += 'Point:' + str(node.point) \
                      + ', Flag: ' + str(bin(node.flag_trav)) \
                      + ', Cursor:' + str(KD_Node.cur_trav) \
                      + '\n'
        s += "Father:" + str(node.father) + '\n'
        s += "L:" + str(node.left) + '\n'
        s += "R:" + str(node.right)
        print(s)
        if node.left:
            preorder(node.left)
        if node.right:
            preorder(node.right)

def draw_KDT(kd):
    """
    Draw a plot in which each of data determined by a point and draw the classifying plane.
    """
    plt.figure(figsize=(6,6))
    plt.xlabel("$x^{(1)}$")
    plt.ylabel("$x^{(2)}$")
    plt.title("Machine Learning: KD Tree")
    plt.xlim(0,6)
    plt.ylim(0,6)
    ax = plt.gca()
    ax.set_aspect(1)
    for node in kd:
        plt.scatter( node.point[0], node.point[1], color='g' )
    plt.show()
    pass

def find_knn(root, x):
    pass


def main():
    kd = None
    kd = CreateKDT(kd, X)

    #preorder(kd)
    for node in kd:
        print( '*' * 50 )
        print( '*' + ' '*17, node.point, node.split, ' '*22 + '*' )
        print( '*' * 50 )
    #preorder(kd)
    kd.clear_trav()
    #preorder(kd)
    draw_KDT(kd)

if __name__ == "__main__":
    main()