K近鄰算法的KD樹實現

#K近鄰算法的KD樹實現
#lichunyu-2020.6.3

import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt

class Node:
    """One kd-tree node: a stored sample row plus links to two subtrees."""

    def __init__(self):
        # Both subtrees start empty; KDTree.construct links them afterwards.
        self.left, self.right = None, None
        # The sample row kept at this node (coordinates followed by the label).
        self.value = []

class Neighbour:
    """Fixed-size buffer of the k nearest candidates found so far.

    ``nk`` holds ``k`` ``(node, distance)`` tuples. Unused slots hold the
    placeholder ``(None, inf)``, so any real candidate displaces them first.
    The largest stored distance doubles as the current search radius.
    """

    def __init__(self, k):
        self.k = k
        # Tuples are immutable, so repeating one placeholder k times is safe.
        self.nk = [(None, float('inf'))] * self.k

    def getMaxDist(self):
        """Return the largest distance in the buffer (``inf`` while a slot is unused)."""
        return max(self.nk, key=lambda elem: elem[1])[1]

    def update_max(self, item):
        """Replace the entry with the largest distance in ``nk`` by ``item``.

        If several entries share the largest distance, the first one is replaced.
        Does nothing when the buffer is empty (``k == 0``).
        """
        if not self.nk:
            return
        # Locate the farthest slot in a single pass, rather than recomputing
        # the maximum distance for every slot examined (O(k) instead of O(k^2)).
        worst = max(range(len(self.nk)), key=lambda i: self.nk[i][1])
        self.nk[worst] = item

    def show(self):
        """Sort the buffer by distance and plot the found neighbours as green crosses.

        Only the first two coordinates of each neighbour are plotted.
        Placeholder slots (fewer than k samples in the tree) are skipped
        instead of crashing on ``None.value``.
        """
        self.nk.sort(key=lambda elem: elem[1])
        found = [node for node, _ in self.nk if node is not None]
        xs = [node.value[0] for node in found]
        ys = [node.value[1] for node in found]
        # A single call gives a single 'nk' legend entry. The previous 'rx' + c='g'
        # was contradictory; the color keyword took precedence, so green is kept.
        plt.plot(xs, ys, 'gx', label='nk')

class KDTree:
    """kd-tree over labelled samples, supporting k-nearest-neighbour search.

    Each row of ``data`` is ``[x1, ..., xd, y]``: ``d`` coordinates followed
    by the class label. The label column is stored with each node but is not
    used for splitting or for distance computation.
    """

    def __init__(self, data, neighbour, p = 2):
        """Build the tree.

        Args:
            data: 2-D array-like of samples, one ``[x1, ..., xd, y]`` row each.
                May be empty, in which case ``root`` is ``None``.
            neighbour: result buffer exposing ``getMaxDist()`` and
                ``update_max((node, distance))``; ``search`` fills it in place.
            p: exponent of the Minkowski (L_p) distance; ``float('inf')``
                selects the Chebyshev (max-coordinate) distance.
        """
        data = np.asarray(data)
        self.p = p  # distance exponent
        self.neighbour = neighbour  # buffer holding the k nearest candidates
        # Number of coordinate columns; the last column is the label.
        self.dimension = data.shape[1] - 1 if data.ndim == 2 else 0
        self.root = self.construct(data, 0)

    def construct(self, data, cur_d):
        """Recursively build the subtree for ``data``, splitting on axis ``cur_d``.

        Rows are sorted on ``cur_d``. The median row becomes this node, rows
        before it form the left subtree and rows after it the right subtree.
        Child subtrees split on the next axis, cycling through all ``dimension``
        coordinates. Returns ``None`` for an empty slice.
        """
        if len(data) == 0:
            return None

        data = data[data[:, cur_d].argsort()]  # sort by the current axis
        mid = len(data) // 2
        node = Node()
        node.value = data[mid]
        next_d = (cur_d + 1) % self.dimension
        node.left = self.construct(data[0 : mid, :], next_d)
        node.right = self.construct(data[mid + 1 :, :], next_d)

        return node

    def search(self, node, pos, cur_d = 0):
        """Depth-first k-NN search of the subtree rooted at ``node``.

        Candidates are pushed into ``self.neighbour``, whose largest stored
        distance is the current search radius. The buffer is not cleared here,
        so each query should start from a fresh buffer.

        Args:
            node: subtree root; ``None`` (e.g. an empty tree) is ignored.
            pos: query coordinates, ``self.dimension`` values (no label).
            cur_d: the axis ``node`` was split on.
        """
        if node is None:
            return

        split = node.value[cur_d]
        # Descend first into the half-space that contains the query point.
        if pos[cur_d] <= split:
            nearer_node, further_node = node.left, node.right
        else:
            nearer_node, further_node = node.right, node.left

        next_d = (cur_d + 1) % self.dimension
        self.search(nearer_node, pos, next_d)

        # Is this node closer than the farthest candidate kept so far?
        distance = self._Lp(node.value[:-1], pos, self.p)
        if distance < self.neighbour.getMaxDist():
            self.neighbour.update_max((node, distance))

        # The other half-space can only contain a closer point if the splitting
        # hyperplane intersects the search hypersphere. Every point on that side
        # is at least |pos[cur_d] - split| away along this axis alone. Comparing
        # against the far child's own coordinate, as before, wrongly pruned
        # subtrees that held the true nearest neighbour.
        if further_node is not None and abs(pos[cur_d] - split) < self.neighbour.getMaxDist():
            self.search(further_node, pos, next_d)

    def _Lp(self, x1, x2, p):
        """Return the Minkowski distance ``(sum |x1_i - x2_i|^p)^(1/p)``.

        ``p = inf`` yields the Chebyshev distance ``max |x1_i - x2_i|``.
        """
        diffs = [abs(a - b) for a, b in zip(x1, x2)]
        if math.isinf(p):
            return max(diffs, default=0.0)
        return math.pow(sum(math.pow(d, p) for d in diffs), 1 / p)

class KNN:
    """k-nearest-neighbour classifier backed by a KDTree.

    ``data`` rows are ``[x1, ..., xd, y]``. ``predict`` returns the majority
    label ``y`` among the ``k`` samples closest to the query point under the
    L_p distance.
    """

    def __init__(self, data, k = 1, p = 2):
        self.neighbour = Neighbour(k)
        self.kdTree = KDTree(data, self.neighbour, p)

    def predict(self, pos):
        """Return the majority label among the k nearest samples to ``pos``.

        ``pos`` holds the coordinates only (no label). After the call,
        ``self.neighbour.nk`` contains the neighbours found for this query.
        """
        # The buffer object is shared with self.kdTree. Without a reset it keeps
        # the previous query's neighbours and distances, which corrupts this
        # search, so it is cleared in place before every query.
        self.neighbour.nk = [(None, float('inf'))] * self.neighbour.k
        if self.kdTree.root is not None:
            self.kdTree.search(self.kdTree.root, pos)
        return self.judge(self.kdTree.neighbour.nk)

    def judge(self, nk):
        """Majority vote over the labels of the neighbours in ``nk``.

        Args:
            nk: list of ``(node, distance)`` tuples; ``node.value[-1]`` is the
                label. Placeholder entries with ``node is None`` are ignored.

        Returns:
            The most frequent label; ties go to the label seen first in ``nk``.

        Raises:
            ValueError: if ``nk`` contains no real neighbours.
        """
        dict_class_times = {}  # label -> number of votes among the neighbours
        for node, _ in nk:
            if node is None:  # unfilled slot: fewer than k samples in the tree
                continue
            belong = node.value[-1]
            dict_class_times[belong] = dict_class_times.get(belong, 0) + 1

        if not dict_class_times:
            raise ValueError("no neighbours to vote on")
        return max(dict_class_times, key=dict_class_times.get)


def test():
    """Demo: classify one point with a 10-NN on two iris classes and plot the result."""
    from sklearn.datasets import load_iris

    # Iris features plus the integer class label as a DataFrame.
    iris = load_iris()
    frame = pd.DataFrame(iris.data, columns=iris.feature_names)
    frame['label'] = iris.target
    frame.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']

    # First 100 rows; columns: sepal length, sepal width, label.
    samples = np.array(frame.iloc[:100, [0, 1, -1]])

    classifier = KNN(samples, k = 10, p = 2)
    query = [5.1, 2.8]
    predicted = classifier.predict(query)
    print(query, "belongs to ", predicted)

    # Overlay the neighbours found, the two plotted groups, and the query point.
    classifier.kdTree.neighbour.show()
    class0, class1 = frame[:50], frame[50:100]
    plt.scatter(class0['sepal length'], class0['sepal width'], c='b', label='0')
    plt.scatter(class1['sepal length'], class1['sepal width'], c='y', label='1')
    plt.plot(query[0], query[1], 'b*', label='test_point')
    plt.xlabel('sepal length')
    plt.ylabel('sepal width')
    plt.show()

if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    test()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章