偶然接觸到了K-means,在理解之後寫下博客記錄。
首先,K-means是一種無監督學習的聚類算法。什麼是聚類算法,聚類就是對大量未標註的數據集,按數據存在的內部特徵特徵劃分爲多個不同的類別。
K-means
算法接受參數k,然後將事先輸入的n個數據劃分爲k個聚類。其中滿足條件:同一聚類對象相似度高,不同聚類對象相似度較小。
算法思想
k個點爲中心聚類,對靠近的對象類歸類,通過迭代,逐漸更新各聚類中心。
算法描述
(1)適當選擇c個類的初始中心
(2)在第k個迭代中,對任意樣本,求其到各中心的距離,將樣本歸到距離最短的中心所在的類
(3)利用均值等方法更新中心值
(4)對於所有的聚類中心,如果利用(2)(3)步驟迭代法更新後,值不變,則迭代結束
對於以上文字來說可能不太容易理解,接下來博主會放圖結合文字來說明K-means
簡單的準備幾個點作爲數據演示:x1(1,1),x2(3,2),x3(1,3),x4(4,7),x5(6,7),x6(5,10),x7(10,4),x8(11,6),x9(12,4)
將這些點繪製出來(左上角爲原點):
將這些點用矩陣表示,每一列表示一個點,第0行表示X值,第一行表示Y值,第二行表示數據所屬類別:
執行算法描述的第一步:
適當選擇c個類的初始中心,我們打算將數據分爲三個類,則需要任意選取三個點作爲聚類中心,我們取(1,1),(3,2),(1,3)作爲初始聚類中心
執行算法描述第二步:
在第k個迭代中,對任意樣本,求其到各中心的距離,將樣本歸到距離最短的中心所在的類,即需要求出事先準備好的9個點到每個聚類中心點的距離,我們構建D矩陣用於存儲距離
D[i][j]表示第(i+1)個聚類點到第(j+1)個數據點的距離,例如:D[1][4] = 5.831 表示第2個聚類點(3,2)到第5個數據點(6,7)的距離
D矩陣每一行表示所有點到同一聚類中心的距離,每一列表示同一點到不同聚類中心的距離,我們將每一列進行比較,得到最小的距離,即可判斷該點屬於哪一類。我們構建G矩陣,將D矩陣按列進行比較,找到最小值位置,對應G矩陣位置 置爲1,例如:D矩陣第一列0,2.2361,2,最小值爲0,對應下標爲D[0][0],則G[0][0] = 1, D矩陣第四列,6.7082,5.099,5,最小值爲5,對應下標D[2][3],則G[2][3] = 1,得到G矩陣,
每一列表示一個數據,第幾行爲1表示該數據屬於哪一類,例如x1的第一行爲1,表示x1當前屬於第一類,x2,x5,x7,x8,x9的第二行爲1,則這些數據當前屬於第二類,x3,x4,x6的第三行爲1,則這些數據屬於第三類。
執行算法描述第三步:
利用均值方法更新中心值,我們將同一類的數據取均值代替原來的聚類點。當前第一類只有一個數據點(1,1),所以聚類點1爲(1/1,1/1)=(1,1),第二類數據有x2,x5,x7,x8,x9,則新的聚類點爲((3+6+10+11+12)/5,2+7+4+6+4)/5)=(8.4,4.6),同理,新的第三個聚類點爲((1+4+5)/3, (3+7+10)/3)=(3.333,6.666)
執行算法描述第四步:
對於所有的聚類中心,如果利用(2)(3)步驟迭代法更新後,值不變,則迭代結束
由於新的聚類點改變了,需要繼續迭代
繼續構建距離矩陣D,將所有點對新的聚類中心求距離
繼續構建G矩陣,找到數據點所屬類別
這次x1,x2,x3屬於第一類,x4,x5,x6屬於第三類,x7,x8,x9屬於第二類,利用均值方法更新中心值,新的聚類點爲(1.666 ,2.000)、 ( 11.000, 4.666)、( 5.000, 8.000),發現聚類點再次變化,繼續迭代
繼續構建距離矩陣D,將所有點對新的聚類中心求距離
繼續構建G矩陣,找到數據點所屬類別
新的聚類點爲(1.666 ,2.000)、 ( 11.000, 4.666)、( 5.000, 8.000),發現聚類點停止改變,停止迭代。
所以我們將x1,x2,x3分爲一類,x4,x5,x6分爲一類,x7,x8,x9分爲一類,將其繪製出來:
以上就是本人對K-means的理解
接下來爲C++、opencv對K-means測試的代碼
#include<iostream>
#include<opencv2\opencv.hpp>
#include<math.h>
using namespace std;
using namespace cv;
#define K 3
//找到兩點之間距離
float getDistance(Point2f A, Point2f B)
{
float distance = 0.0;
distance = sqrt(pow(A.x - B.x, 2) + pow(A.y - B.y,2));
return distance;
}
//找出vector中的最小值
int getMinIndex(vector<float>data)
{
float index = 0;
float min = 10000;
for (int i = 0; i < K; i++)
{
if (data[i] < min)
{
min = data[i];
index = i;
}
}
return index;
}
//找出vector中值爲1的下標
vector<int> getIndexIsOne(vector<float>data)
{
vector<int>index;
for (int i = 0; i < data.size(); i++)
{
if (data[i] == 1)
index.push_back(i);
}
return index;
}
//是否停止迭代
bool shouldStop(vector<Point2f>oldCentroids, vector<Point2f>centroids, int iterations, int maxIt)
{
if (iterations > maxIt)
return true;
return oldCentroids == centroids;
}
//更新數據類別
void updateLabels(Mat &dataset, vector<Point2f>points, vector<Point2f>¢roids)
{
//構建D0矩陣 K行N列,用與記錄每個點與聚類點的距離
Mat D0 = Mat::zeros(Size(points.size(), K), CV_32F);
//計算每個點與聚類點之間的距離,D0[i][j]表示第i個數據點與第j個聚類點的距離
for (int i = 0; i < K; i++)
{
for (int j = 0; j < points.size(); j++)
{
D0.at<float>(i, j) = getDistance(points[j], centroids[i]);
}
}
//構建G矩陣 K行N列,按列進行比較,找到最小值位置,對應G矩陣位置 置爲1
Mat G0 = Mat::zeros(Size(points.size(), K), CV_32F);
for (int i = 0; i < points.size(); i++)
{
Mat col;
//獲取每一列數據後,使用reshape轉換成vector便於計算
D0.colRange(i, i + 1).copyTo(col);
//reshape(cn,row)
vector<float>colsVec(col.reshape(1, 1));
G0.at<float>(getMinIndex(colsVec), i) = 1;
}
for (int i = 0; i < K; i++)
{
Mat row;
//獲取每一行數據後,使用reshape轉換成vector便於計算
G0.rowRange(i, i + 1).copyTo(row);
vector<float>rowsVec(row.reshape(1, 1));
vector<int>indexVec;
indexVec = getIndexIsOne(rowsVec);
int xSum = 0.0;
int ySum = 0.0;
//利用均值更新聚類點
for (int j = 0; j < indexVec.size(); j++)
{
dataset.at<float>(2, indexVec[j]) = i*1.0;//bug
xSum += points[indexVec[j]].x;
ySum += points[indexVec[j]].y;
}
centroids[i].x = xSum*1.0 / indexVec.size();
centroids[i].y = ySum*1.0 / indexVec.size();
}
}
Mat Kmeans(vector<Point2f>points, int classification, int maxIt)
{
//創建一個 3行,N列的數據,多出來的一行用於表示數據類別
Mat dataset = Mat::zeros(Size(points.size(), 3), CV_32FC1);
for (int i = 0; i < points.size(); i++)
{
dataset.at<float>(0, i) = points[i].x;
dataset.at<float>(1, i) = points[i].y;
}
vector<Point2f>centroids(3);
//初始化聚類點,任意取K個數據作爲初始數據
for (int i = 0; i < K; i++)
{
centroids[i] = points[i];
}
int iterations = 0;
//用於比較聚類點是否發生變化
vector<Point2f>oldCentroids(3,Point(0,0));
while (!shouldStop(oldCentroids, centroids, iterations, maxIt))
{
iterations++;
oldCentroids.assign(centroids.begin(), centroids.end());
updateLabels(dataset, points, centroids);
}
return dataset;
}
//繪圖
void DrawMat(Mat dataset,Mat &drawingBoard)
{
for (int i = 0; i < dataset.cols; i++)
{
if (dataset.at<float>(2, i) == 0)
//circle(drawingBoard,Point(dataset.at<float>(1, i),)
drawingBoard.at<Vec3b>(dataset.at<float>(0, i), dataset.at<float>(1, i)) = Vec3b(0, 0, 255);
else if (dataset.at<float>(2, i) == 1)
drawingBoard.at<Vec3b>(dataset.at<float>(0, i), dataset.at<float>(1, i)) = Vec3b(0, 255, 0);
else if (dataset.at<float>(2, i) == 2)
drawingBoard.at<Vec3b>(dataset.at<float>(0, i), dataset.at<float>(1, i)) = Vec3b(255, 0, 0);
}
imshow("散點分類圖",drawingBoard);
}
int main()
{
Point2f x1 = Point2f(1, 1);
Point2f x2 = Point2f(3, 2);
Point2f x3 = Point2f(1, 3);
Point2f x4 = Point2f(4, 7);
Point2f x5 = Point2f(6, 7);
Point2f x6 = Point2f(5, 10);
Point2f x7 = Point2f(10, 4);
Point2f x8 = Point2f(11, 6);
Point2f x9 = Point2f(12, 4);
//將數據放入容器中,便於計算
vector<Point2f>points;
points.push_back(x1);
points.push_back(x2);
points.push_back(x3);
points.push_back(x4);
points.push_back(x5);
points.push_back(x6);
points.push_back(x7);
points.push_back(x8);
points.push_back(x9);
Mat dataset1 = Mat::zeros(20, 20, CV_8UC3);
for (int i = 0; i < points.size(); i++)
{
dataset1.at<Vec3b>(points[i]) = Vec3b(255, 255, 255);
}
Mat dataset = Kmeans(points, K, 200);
Mat drawingBoard = Mat::zeros(20, 20, CV_8UC3);
DrawMat(dataset, drawingBoard);
waitKey(0);
return 0;
}