K-means學習總結

K-means學習總結

前言

最近在看fasttext源碼,看到裏面壓縮用到kmeans方法,不得不說fasttext實現的比較繞,c++語言一方面,沒註釋一方面,代碼邏輯也有一點,理解確實困難,雖然看kmeans的原理並不複雜。所以去翻了scikit-learn的kmeans源碼,幫助理解消化。看的過程中發現kmeans有改進版本,本文不涉及,只關注最簡單的kmeans實現。

關鍵步驟

kmeans是一種聚類方法,主要是從N個樣本中分出K個簇(cluster),在每個簇中找出一箇中心(centroid),在每個簇中的樣本都與這個簇的中心點的歐式距離(euclidean distance)之和最小:

E=i=1kxCixμi2E=\sum_{i=1}^{k}\sum_{x\in{C_{i}}}||x-\mu_{i}||^{2}

k-means算法主要實現有以下幾個步驟:

  1. 中心點初始化。可以從樣本集中隨機抽取K樣本作爲初始值;或根據經驗值指定K箇中心點。
  2. 簇填充。遍歷樣本集,計算每個樣本與每個中心點的距離,並將樣本填充至距離最近的簇。
  3. 重新計算中心點。根據已劃分的簇,計算簇內所有樣本的均值,並作爲新的中心點。
  4. 重複2、3,直至達到指定的迭代次數,或者誤差滿足要求。

代碼實現

參考scikit-learn的代碼簡單實現了k-means算法(未考慮各種異常),加強理解。

# author: delta_hell
# date: 2020-06-20

from collections import deque
from math import sqrt

import numpy as np
from scipy.cluster.vq import vq, kmeans, whiten

# 計算樣本與中心點的距離
def _distL2(obs, guess):
    dist = 0
    for k in range(obs.shape[0]):
        tmp = obs[k] - guess[k]
        dist += tmp * tmp
    # fasttext中並未使用sqrt。這裏爲了與scikit-learn保持一致
    return sqrt(dist)

# 步驟2 之 遍歷計算所有樣本與中心點的距離
def _calc_dists(obs, code_book):
    out = np.zeros((obs.shape[0], code_book.shape[0]))
    for i in range(obs.shape[0]):
        for j in range(code_book.shape[0]):
            out[i,j] = _distL2(obs[i], code_book[j])
    return out

# 步驟2。
def _vq(obs, code_book):
    if obs.ndim == 1:
        obs = obs[:, np.newaxis]
        code_book = code_book[:, np.newaxis]
    # 遍歷計算距離。
    dist = _calc_dists(obs, code_book)
    # 找到距離每個樣本最近的中心點。
    code = dist.argmin(axis=1)
    # 每個樣本的最短距離
    min_dist = dist[np.arange(len(code)), code]
    return code, min_dist

# 步驟3。
def _update_cluster_means(obs, obs_code):
    # 找出所有中心點。(可以不使用sorted)
    centroids = sorted(set(obs_code.tolist()))
    code_book = []
    has_members = []
    for x in centroids:
        # 找出屬於同一個簇的樣本。(同一個中心點)
        filtered = obs[obs_code == x]
        if filtered.shape[0] < 1:
            has_members.append(False)
        # 計算簇內樣本均值
        means = filtered.mean(axis = 0)
        # 均值作爲新的中心點
        code_book.append(means)
        has_members.append(True)
    # has_members主要用於剔除沒有樣本的簇
    return np.array(code_book), has_members

# k-means算法實現
def _kmeans(obs, code_book, thresh):
    diff = np.inf
    prev_avg_dists = deque([diff], maxlen=2)
    # 與上次劃分的誤差是否大於閾值
    while diff > thresh:
        # 步驟2
        obs_code, distort = _vq(obs, code_book)
        # 所有簇的平均誤差
        prev_avg_dists.append(distort.mean(axis=-1))
        # 步驟3
        code_book, has_members = _update_cluster_means(obs, obs_code)
        code_book = code_book[has_members]
        # 兩次劃分的誤差
        diff = prev_avg_dists[0] - prev_avg_dists[1]

    return code_book, prev_avg_dists[1]

# 步驟1。
def _kpoints(data, k):
    idx = np.random.choice(data.shape[0], size=k, replace=False)
    return data[idx]

# 入口
# obs: 樣本集
# k_or_guess: 指定簇數K或者給出經驗值。
# iter: 指定迭代次數。
# thresh:誤差閾值。(兩次中心點劃分的誤差,如果小於該閾值,則說明當前劃分已經最優,不需要重新計算中心點,該次k-means算法終止。)
def my_kmeans(obs, k_or_guess, iter=20, thresh=1e-5):
	# 如果是指定中心點,則不迭代。
    if not np.isscalar(k_or_guess):
        return _kmeans(obs, k_or_guess, thresh=thresh)

    k = int(k_or_guess)
    best_dist = np.inf
    # 迭代
    for i in range(iter):
        # inital code book
        guess = _kpoints(obs, k)
        # 計算
        book, dist = _kmeans(obs, guess, thresh)
        # 選取最優中心點
        if dist < best_dist:
            best_dist = dist
            best_book = book
    return best_book, best_dist

代碼量比我想象的多了點。。。

測試驗證

寫完之後,單步DEBUG了幾個關鍵點,沒啥問題,最終結果也與scikit-learn的結果一致,應該沒啥毛病。

features  = np.array([[ 1.9,2.3],
                   [ 1.5,2.5],
                   [ 0.8,0.6],
                   [ 0.4,1.8],
                   [ 0.1,0.1],
                   [ 0.2,1.8],
                   [ 2.0,0.5],
                   [ 0.3,1.5],
                   [ 1.0,1.0]])
whitened = whiten(features)
print(whitened)
print("----------------------------")
book = np.array((whitened[0],whitened[2]))
print(book)
print("-----------------------------")
code, dist = kmeans(whitened,book)
print(code)
print(dist)
print("-----------------------------")
code, dist = my_kmeans(whitened, book)
print(code)
print(dist)

結語

scikit-learn文檔上關於k-means還有改進版本以及並行計算的內容,有時間再看吧。fasttext的實現與當前這個版本還有不同,後面寫fasttext再說吧。感嘆一下,看起來簡單的東西都不敢說真懂啊,慢慢積累吧。

附錄

  1. scikit-learn k-means
  2. K-Means
  3. k-means wiki
  4. K均值原理及實現(K-Means)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章