K均值聚類算法(Kmeans)講解及源碼實現
算法核心
K均值聚類的核心目標是將給定的數據集劃分成K個簇,並給出每個數據對應的簇中心點。算法的具體步驟描述如下。
- 數據預處理,如歸一化、離羣點處理等。
- 隨機選取K個簇中心,記爲。
- 定義代價函數:。
- 令爲迭代步數,重複下面過程直到收斂
- 對於每一個樣本,將其分配到距離最近的簇
- 對於每一個類簇,重新計算該類簇的中心
均值算法在迭代時,假設當前損失函數沒有達到最小值,那麼首先固定簇中心,調整每個樣例所屬的類別來讓函數減少;
然後固定,調整簇中心使減少。
這兩個過程交替循環,單調遞減:當遞減到最小值時,和也同時收斂。
源碼實現(含可視化)
導入包
import numpy as np
import matplotlib.pyplot as plt
數據預處理
設置地圖尺寸
# map 100*100
high = 100
width = 100
創建隨機數據
每一條數據的格式爲,列表初始化爲0,類別序數間隔1遞增
data = np.random.rand(100, 2)
data = data * [high, width]
data = np.hstack((data, np.zeros([100, 1])))
定義簇數目
# count of classes
classes = 5
定義距離函數,此處我們使用歐氏距離
def distance(point1, center):
return np.sqrt((point1[0] - center[0]) ** 2 + (point1[1] - center[1]) ** 2)
定義從類別到顏色的映射函數,即
def color(i):
global classes
return i * 255. / classes
定義主函數
先將plt設置爲連續作圖模式
然後隨機挑選簇中心點,並加入到中心點數組中
if __name__ == '__main__':
plt.ion()
# select center randomly
centers = np.random.randint(0, 100, [classes])
centers_data = []
for i in range(classes):
data[i][2] = i
centers_data.append(data[i])
先畫散點圖,且暫停0.5秒以顯示迭代中的聚類情況。
while True:
colors = [color(x) for x in data[:, 2]]
plt.scatter(data[:, 0], data[:, 1], c=colors)
plt.pause(0.5)
先後依次迭代更新每個點所對應的簇,和每個簇的中心點。
# caculate nearest center
for i in range(100):
distances = np.array([distance(data[i], center_data) for center_data in centers_data])
i_class = np.argmin(distances)
data[i][2] = i_class
# caculate new center
new_centers_data = np.zeros([classes, 2])
new_centers_count = np.zeros([classes])
for j in range(5):
for i in range(100):
if data[i][2] == j:
new_centers_count[j] += 1
new_centers_data[j] += data[i][0:2]
計算五個簇的中心點位置先後變化的最大值,其值小於1e-4(可自定義)時,跳出循環,停止迭代。
new_centers_data /= np.array([new_centers_count]).T
dist = np.max([distance(new_centers_data[i], centers_data[i]) for i in range(classes)])
print('max distance ', dist)
if dist < 1e-4:
break
在每次迭代的最後更新中心點數據
centers_data = new_centers_data
最後關閉連續作圖模式,並展示最後的圖畫,打印結束。
plt.ioff()
plt.show()
print('kmeans completed.')
效果
命令行
max distance 28.36595846126929
max distance 7.136259328045152
max distance 3.533885366585787
max distance 2.153654229308223
max distance 0.0
kmeans completed.
可視化過程
第1次迭代
第2次迭代
第3次迭代
第4次迭代
全部代碼
import numpy as np
import matplotlib.pyplot as plt
# map 100*100
high = 100
width = 100
# create random data
data = np.random.rand(100, 2)
data = data * [high, width]
data = np.hstack((data, np.zeros([100, 1])))
# count of classes
classes = 5
def distance(point1, center):
return np.sqrt((point1[0] - center[0]) ** 2 + (point1[1] - center[1]) ** 2)
def color(i):
global classes
return i * 255. / classes
if __name__ == '__main__':
plt.ion()
# select center randomly
centers = np.random.randint(0, 100, [classes])
centers_data = []
for i in range(classes):
data[i][2] = i
centers_data.append(data[i])
while True:
colors = [color(x) for x in data[:, 2]]
plt.scatter(data[:, 0], data[:, 1], c=colors)
plt.pause(0.5)
# caculate nearest center
for i in range(100):
distances = np.array([distance(data[i], center_data) for center_data in centers_data])
i_class = np.argmin(distances)
data[i][2] = i_class
# caculate new center
new_centers_data = np.zeros([classes, 2])
new_centers_count = np.zeros([classes])
for j in range(5):
for i in range(100):
if data[i][2] == j:
new_centers_count[j] += 1
new_centers_data[j] += data[i][0:2]
new_centers_data /= np.array([new_centers_count]).T
dist = np.max([distance(new_centers_data[i], centers_data[i]) for i in range(classes)])
print('max distance ', dist)
if dist < 1e-4:
break
centers_data = new_centers_data
plt.ioff()
plt.show()
print('kmeans completed.')