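A Short Skit
Onlooker: A KD tree? Traversed without recursion? And without a queue either? Quit kidding, kid.
Me: Believe it or not, I'm that good, and that confident.
Onlooker: Then do you dare show it?
Me: Just watch...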
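Main Text
In this post I present a way to traverse a KD tree (a binary tree) with a Python iterator. For an introduction to iteration, see this article I wrote.
URL: http://blog.csdn.net/weixin_37722024/article/details/62424311
I won't go into how to write an iterator here. I will only explain the member variables the traversal needs and the traversal logic inside the member function __next__().
Python iterator documentation: https://docs.python.org/3/tutorial/classes.html#iterators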
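1. Structure of the KD tree class
Building a KD tree, which is just a binary tree, is fairly simple: construct the root node first, then recursively construct the root's left and right subtrees. That is outside the scope of this post.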
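Although construction is out of scope, one step of it is worth a quick look because the traversal output depends on it: choosing the split dimension and the median point. The sketch below is only an illustration under stated assumptions. The helper name choose_split is hypothetical, and it uses the standard library (statistics.pvariance and sorted) in place of the NumPy calls in the full listing. The sample points are the ones used in the full listing.

```python
from statistics import pvariance

# Sample points taken from the full listing in this post.
points = [(3, 5), (2, 4), (1, 1), (5, 2), (1, 5), (4, 1)]

def choose_split(data):
    """Hypothetical helper: return (split axis, median point, left part, right part)."""
    dims = len(data[0])
    # Population variance of every dimension; the largest one is split on.
    variances = [pvariance([p[d] for p in data]) for d in range(dims)]
    split = variances.index(max(variances))
    # Sort along the split axis and take the middle element as this node.
    ordered = sorted(data, key=lambda p: p[split])
    pos = len(ordered) // 2
    return split, ordered[pos], ordered[:pos], ordered[pos + 1:]

split, point, left, right = choose_split(points)
print(split, point, left, right)
```

The left and right parts would then be passed to the same function recursively to build the two subtrees.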
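Besides the node value, left subtree and right subtree that every binary tree needs, and the split that each KD tree node needs, the KD_Node class defines three extra members for traversal. cur_trav is a class variable shared by all instances; it holds a KD_Node and serves as the cursor during the search. flag_trav is an instance variable of type int that records each node's state during traversal. father is an instance variable of type KD_Node that stores each node's parent. With this layout each node carries one extra int and one reference to its parent, and the class carries a single cursor shared by every node (each node being an instance of the class). The memory overhead is small, it buys a lot of flexibility, and the code stays concise. The same traversal can do more than collect node values: it can also drive matplotlib to draw the KD tree's partition of space, or even simulate a KNN search. That is what I intend for the draw_KDT function, but it isn't finished yet.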
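The difference between the shared cursor and the per-node flag can be seen in a few lines. This is a minimal sketch, not the KD_Node class itself: the names Node, cursor and flag are stand-ins chosen for illustration.

```python
class Node:
    cursor = None              # class attribute: one slot shared by every node

    def __init__(self, value, father=None):
        self.value = value
        self.father = father   # instance attribute: reference to the parent
        self.flag = 0          # instance attribute: one flag per node

root = Node("root")
child = Node("child", father=root)

Node.cursor = child            # assigned through the class, as __next__() does
print(root.cursor is child.cursor)   # every node sees the same cursor

root.flag |= 0x01              # changing one node's flag leaves the others alone
print(root.flag, child.flag)
```

One consequence of the shared cursor is that only one traversal can be in progress at a time, even across different trees built from the same class.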
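2. The iteration process
As mentioned above, the search logic lives entirely in __next__(); see the code for the exact flow. When I started on the iterator I wanted to write __next__() recursively. I tried, it didn't work, so I gave up, started over, and came up with this thing. The variables used for traversal keep their initial values once the KD tree has been built. When traversal starts, the cursor cur_trav is pointed at the root node, which becomes the node currently being searched. During the search, the state flag_trav of the node under the cursor decides whether to return the current node, move to its left child, or move to its right child. The state therefore uses only the lowest 3 bits: bit0 records whether the current node has been visited, bit1 whether its left child has been visited, and bit2 whether its right child has been visited. father exists so that the cursor can return to the parent node and from there reach the other subtree.
How the node state flag_trav changes:
1. The initial value is 0, meaning that the current node, its left child and its right child have not been searched yet.
2. The order is preorder: the current node first, then the left, and finally the right. So [bit2,bit1,bit0] only ever takes four states: [000], [001], [011], [111].
3. A child that does not exist is skipped. If neither child exists, the cursor moves to the parent of the current node.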
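To make the state changes concrete, here is a small sketch that applies the same flag rules to a four-node tree and records every flag change. It is an illustration, not the code from this post: it uses a generator instead of __next__() for brevity, and the names Node, walk, mark and history are hypothetical.

```python
class Node:
    def __init__(self, name, left=None, right=None):
        self.name, self.left, self.right = name, left, right
        self.father = None
        self.flag = 0
        for child in (left, right):
            if child is not None:
                child.father = self

history = []                        # (node name, flag bits) after every change

def mark(node, bit):
    node.flag |= bit
    history.append((node.name, format(node.flag, "03b")))

def walk(root):
    """Yield nodes in the order the flag-driven loop visits them."""
    cursor = root
    while True:
        if cursor.flag & 0x07 == 0x07:      # node and both subtrees are done
            if cursor.father is None:
                return                      # back at the finished root: stop
            cursor = cursor.father
        elif cursor.flag & 0x01 == 0:       # node itself not yet returned
            mark(cursor, 0x01)
            yield cursor
        elif cursor.flag & 0x02 == 0:       # left subtree not yet entered
            mark(cursor, 0x02)
            if cursor.left is not None:
                cursor = cursor.left
        else:                               # right subtree not yet entered
            mark(cursor, 0x04)
            if cursor.right is not None:
                cursor = cursor.right

#       A
#      / \
#     B   C
#      \
#       D
tree = Node("A", Node("B", right=Node("D")), Node("C"))
order = [n.name for n in walk(tree)]
print(order)
print(history)
```

Every node passes through 001, 011 and 111 in that order, and the nodes come out as node, left subtree, right subtree.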
3. Code
I'll paste the key pieces first. If you can't be bothered to read the long listing, just skim these.
Snippet 1, the class declaration:
class KD_Node:
    cur_trav = None     # traversal cursor, shared by all instances

    def __init__(self, point=None, split=None, L=None, R=None, father=None):
        """
        Initialize a KD tree node.
        point: datum of this node
        split: index of the splitting dimension for this node
        L: left child
        R: right child
        father: parent of this node; None for the root
        """
        self.point = point
        self.split = split
        self.left = L
        self.right = R
        self.father = father
        self.flag_trav = 0      # traversal flag:
                                # bit 0: this node has been visited
                                # bit 1: its left subtree has been entered
                                # bit 2: its right subtree has been entered
......
Snippet 2, the __next__() function:
    def __next__(self):
        # Traverse the tree without recursion.
        if KD_Node.cur_trav is None:        # first call: start at this node
            KD_Node.cur_trav = self
        cursor = KD_Node.cur_trav
        while True:
            if (cursor.flag_trav & 0x07) == 0x07:   # flag == 7: this node and both
                                                    # of its subtrees are finished
                if cursor.father is None:           # the finished node is the root
                    raise StopIteration
                else:
                    cursor = cursor.father          # climb back to the parent
            elif (cursor.flag_trav & 0x01) == 0:    # bit0 == 0: node not returned yet
                cursor.flag_trav |= 0x01            # set bit0 = 1
                break                               # leave the loop, return this node
            elif (cursor.flag_trav & 0x02) == 0:    # bit1 == 0: left not entered yet
                cursor.flag_trav |= 0x02            # set bit1 = 1
                if cursor.left is not None:
                    cursor = cursor.left            # move the cursor to the left child
                # no left child: stay on this node and check again
            elif (cursor.flag_trav & 0x04) == 0:    # bit2 == 0: right not entered yet
                cursor.flag_trav |= 0x04            # set bit2 = 1
                if cursor.right is not None:
                    cursor = cursor.right           # move the cursor to the right child
                # no right child: stay on this node and check again
        KD_Node.cur_trav = cursor
        return KD_Node.cur_trav
Snippet 3, a concise traversal:
def main():
    kd = None
    kd = CreateKDT(kd, X)
    for node in kd:
        print( '*' + ' '*17, node.point, node.split, ' '*22 + '*' )
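One thing to keep in mind with this design is that the traversal state lives in the nodes, so a second for loop over the same tree yields nothing until the flags are reset. The full listing handles this by calling clear_trav() between its two loops. The sketch below shows the effect on a compact stand-in class; Node, cur, flag and clear are hypothetical names that mirror KD_Node, cur_trav, flag_trav and clear_trav.

```python
class Node:
    cur = None                                  # shared cursor, like cur_trav

    def __init__(self, point, left=None, right=None):
        self.point, self.left, self.right = point, left, right
        self.father, self.flag = None, 0
        for child in (left, right):
            if child is not None:
                child.father = self

    def __iter__(self):
        return self

    def __next__(self):
        cursor = Node.cur if Node.cur is not None else self
        while True:
            if cursor.flag == 0x07:             # node and both subtrees done
                if cursor.father is None:
                    raise StopIteration
                cursor = cursor.father
            elif not cursor.flag & 0x01:        # return the node itself
                cursor.flag |= 0x01
                Node.cur = cursor
                return cursor
            elif not cursor.flag & 0x02:        # then try the left child
                cursor.flag |= 0x02
                if cursor.left is not None:
                    cursor = cursor.left
            else:                               # then try the right child
                cursor.flag |= 0x04
                if cursor.right is not None:
                    cursor = cursor.right

    def clear(self):
        """Reset the shared cursor and the flags of this subtree."""
        Node.cur = None
        self.flag = 0
        for child in (self.left, self.right):
            if child is not None:
                child.clear()

root = Node((2, 4), Node((5, 2)), Node((3, 5)))
first = [n.point for n in root]
second = [n.point for n in root]    # flags are all still 0x07
root.clear()
third = [n.point for n in root]
print(first, second, third)
```

The second loop is empty, and the third one, run after the reset, repeats the first.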
The complete code follows:
import numpy as np
import matplotlib.pyplot as plt

"""
X, feature vectors
Y, class of X
X_with_class, I just use this to draw a graphic in the piece of code at
    the bottom.
D, dimension of each of vectors.
"""
# Construct initial to be classified data
X = np.array([ (3,5), (2,4), (1,1), (5,2), (1,5), (4,1) ])
Y = [ 'g', 'g', 'r', 'r', 'g', 'r' ]
X_with_class = [ [X[a,0],X[a,1],Y[a]] for a in range(len(X)) ]
D = 0
if len(X[0]):
    D = len(X[0])

class KD_Node:
    cur_trav = None     # traversal cursor, shared by all instances

    def __init__(self, point=None, split=None, L=None, R=None, father=None):
        """
        Initialize a KD tree node.
        point: datum of this node
        split: index of the splitting dimension for this node
        L: left child
        R: right child
        father: parent of this node; None for the root
        """
        self.point = point
        self.split = split
        self.left = L
        self.right = R
        self.father = father
        self.flag_trav = 0      # traversal flag:
                                # bit 0: this node has been visited
                                # bit 1: its left subtree has been entered
                                # bit 2: its right subtree has been entered

    def clear_trav(self):
        # Reset the shared cursor and the flags of this whole subtree.
        KD_Node.cur_trav = None
        self.flag_trav = 0
        if self.left:
            self.left.clear_trav()
        if self.right:
            self.right.clear_trav()

    def __iter__(self):
        return self

    def __next__(self):
        # Traverse the tree without recursion.
        if KD_Node.cur_trav is None:        # first call: start at this node
            KD_Node.cur_trav = self
        cursor = KD_Node.cur_trav
        while True:
            if (cursor.flag_trav & 0x07) == 0x07:   # flag == 7: this node and both
                                                    # of its subtrees are finished
                if cursor.father is None:           # the finished node is the root
                    raise StopIteration
                else:
                    cursor = cursor.father          # climb back to the parent
            elif (cursor.flag_trav & 0x01) == 0:    # bit0 == 0: node not returned yet
                cursor.flag_trav |= 0x01            # set bit0 = 1
                break                               # leave the loop, return this node
            elif (cursor.flag_trav & 0x02) == 0:    # bit1 == 0: left not entered yet
                cursor.flag_trav |= 0x02            # set bit1 = 1
                if cursor.left is not None:
                    cursor = cursor.left            # move the cursor to the left child
                # no left child: stay on this node and check again
            elif (cursor.flag_trav & 0x04) == 0:    # bit2 == 0: right not entered yet
                cursor.flag_trav |= 0x04            # set bit2 = 1
                if cursor.right is not None:
                    cursor = cursor.right           # move the cursor to the right child
                # no right child: stay on this node and check again
        KD_Node.cur_trav = cursor
        # print("cursor:", KD_Node.cur_trav, "self flag:", KD_Node.cur_trav.flag_trav)
        return KD_Node.cur_trav

def CreateKDT(node=None, data=None, father=None):
    """
    Recursively build a KD tree from data and return its root node.
    INPUT:  node,   unused placeholder; a new KD_Node is always created
            data,   2-D numpy array of points, e.g. [ (3,5), (2,4), (1,1) ]
            father, the parent of the node being built (None for the root)
    OUTPUT: the KD_Node at the root of the (sub)tree built from data
    """
    # Variance of each dimension; the dimension with the largest is split on.
    dim = D
    var = np.var(data, axis=0)
    split = np.argmax(var)
    # Position of the current node in the data: the middle.
    pos = len(data) // 2
    # Partition the data around that position along the split dimension.
    pos_list = np.argpartition(data[:,split], pos)
    point = data[pos_list[pos]]
    """
    print procedure
    print("#"*20)
    print("data:",list(data))
    print("split:",split,",",data[:,split])
    print("pos list: ",pos_list)
    print("pos:",pos)
    print("the node data:",data[pos_list[pos]] )
    """
    """
    print procedure
    print("DEBUG: LEFT=", data[pos_list[:pos]] )
    print("DEBUG: RIGHT=", data[pos_list[(pos+1):]] )
    """
    node = KD_Node(point, split, father=father)
    if len(data) > 1:
        if len(data[pos_list[:pos]]) != 0:
            node.left = CreateKDT(node.left, data[pos_list[:pos]], node)
        if len(data[pos_list[(pos+1):]]) != 0:
            node.right = CreateKDT(node.right, data[pos_list[(pos+1):]], node)
    return node

def get_split_pos(data, split):
    """return the position to split in data."""
    # Unused helper; returns the same middle position CreateKDT computes.
    pos = len(data) // 2
    return pos

def preorder(node, depth=-1):
    """
    Print a KD node and its subtrees in preorder (recursive, for debugging).
    """
    if node:
        s = '#' + '-'*50 + '#\n'
        s += 'Node:' + str(node) + '\n'
        s += 'Point:' + str(node.point) \
            + ', Flag: ' + str(bin(node.flag_trav)) \
            + ', Cursor:' + str(KD_Node.cur_trav) \
            + '\n'
        s += "Father:" + str(node.father) + '\n'
        s += "L:" + str(node.left) + '\n'
        s += "R:" + str(node.right)
        print(s)
        if node.left:
            preorder(node.left)
        if node.right:
            preorder(node.right)

def draw_KDT(kd):
    """
    Scatter-plot every point stored in the tree.
    Drawing the splitting planes is not implemented yet.
    """
    plt.figure(figsize=(6,6))
    plt.xlabel("$x^{(1)}$")
    plt.ylabel("$x^{(2)}$")
    plt.title("Machine Learning: KD Tree")
    plt.xlim(0,6)
    plt.ylim(0,6)
    ax = plt.gca()
    ax.set_aspect(1)
    for node in kd:
        plt.scatter( node.point[0], node.point[1], color='g' )
    plt.show()

def find_knn(root, x):
    pass

def main():
    kd = None
    kd = CreateKDT(kd, X)
    #preorder(kd)
    for node in kd:
        print( '*' * 50 )
        print( '*' + ' '*17, node.point, node.split, ' '*22 + '*' )
        print( '*' * 50 )
    #preorder(kd)
    kd.clear_trav()     # reset the flags so the tree can be traversed again
    #preorder(kd)
    draw_KDT(kd)

if __name__ == "__main__":
    main()