NearestNeighbors 是 Julia 中一個效率較高的最近鄰搜索代碼庫(支持 k 近鄰(KNN)搜索和範圍搜索),它提供了 BallTree、KDTree 等多種數據結構。
這裏使用 KDTree 結構搜索與給定點歐氏距離最近的數據點,並繪製圖表。數據仍然使用鳶尾花(iris)數據集:以第一個樣本作為查詢點,在其餘樣本中搜索距離最近的 3 個點。
代碼示例
using RDatasets
using DataFrames
using CSV
using NearestNeighbors
using Colors
using PyPlot
using PyCall
import PyPlot:plot
import NearestNeighbors.HyperSphere
@pyimport matplotlib.patches as patch
# Load the iris data set and keep its four numeric measurement columns,
# transposed so that every *column* of `features` is one sample -- the
# layout the NearestNeighbors trees index.
iris = dataset("datasets", "iris");
features = permutedims(Matrix(iris[:, 1:4]));

# Split off the query: sample 1 is the point we search for, and all other
# samples become the search set (both right-hand sides are evaluated
# before `features` is rebound).
point, features = features[:, 1], features[:, 2:end]

# Index the search set with a KD-tree.
kdtree = KDTree(features)

# Request the k nearest neighbours of `point`; the trailing `true`
# (sortres) returns them ordered from nearest to farthest.
k = 3
idxs, dists = knn(kdtree, point, k, true)
# Column indices of the neighbours that were found. These index the *reduced*
# `features` matrix (the query sample was removed above), so index j here is
# column j + 1 of the full feature matrix, i.e. row j + 1 of `iris`.
idxs
# 3-element Array{Int64,1}:
# 17
# 4
# 39
# Feature vectors of the three neighbours (echoed only in the REPL).
# NOTE(review): these indices are copied from the output shown above;
# `features[:, idxs]` would stay correct if the data, query or k change.
features[:,17]
features[:,4]
features[:,39]
# Euclidean distances from `point` to each neighbour, nearest first
# (knn was called with sortres = true).
dists
# 3-element Array{Float64,1}:
# 0.09999999999999998
# 0.1414213562373093
# 0.14142135623730964
# Split the search set into the k neighbours returned by knn (`data`) and all
# remaining samples (`other`). Both are transposed, so each row is one sample
# and the four features are the columns; the plotting code iterates over rows.
#
# Both are derived from `idxs` rather than hard-coded index ranges, so every
# sample lands in exactly one of the two groups. The neighbour order of `idxs`
# is preserved, and `other` keeps ascending column order. The split stays
# correct when the data, the query point or k change.
data = features[:, idxs]'
other = features[:, setdiff(axes(features, 2), idxs)]'
# Plot feature 3 (x axis) against feature 4 (y axis), i.e. PetalLength vs
# PetalWidth in the RDatasets iris column order. The plot shows three groups:
# all remaining samples, the k nearest neighbours, and the query point, each
# in its own colour. The figure is then saved to "iris.png".

# Four mutually distinguishable colours seeded with black. cols[1] (the black
# seed) is unused; cols[2:4] colour the three groups.
cols = distinguishable_colors(4, RGB(0,0,0))
# Convert a Colors.jl colour to the (r, g, b) float tuple matplotlib accepts.
mplcolor(c) = (Float64(c.r), Float64(c.g), Float64(c.b))

cfig = figure()
# Property syntax: `obj[:attr]` attribute access is deprecated in PyCall.
ax = cfig.add_subplot(1, 1, 1)
ax.set_aspect("equal")
axis((0.0, 9.0, 0.0, 3.0))

# One plot call per group instead of one per point. A marker-only format
# string ("*") draws the points without connecting lines.
plot(other[:, 3], other[:, 4], "*", color = mplcolor(cols[2]))
plot(data[:, 3], data[:, 4], "*", color = mplcolor(cols[3]))
plot(point[3], point[4], "*", color = mplcolor(cols[4]))
title("iris")
cfig.savefig("iris.png")