K-means算法

偶然接觸到了K-means,在理解之後寫下博客記錄。

首先,K-means是一種無監督學習聚類算法。什麼是聚類算法,聚類就是對大量未標註的數據集,按數據存在的內部特徵特徵劃分爲多個不同的類別。

K-means

算法接受參數k,然後將事先輸入的n個數據劃分爲k個聚類。其中滿足條件:同一聚類對象相似度高,不同聚類對象相似度較小。

算法思想

k個點爲中心聚類,對靠近的對象類歸類,通過迭代,逐漸更新各聚類中心。

算法描述

(1)適當選擇c個類的初始中心

(2)在第k個迭代中,對任意樣本,求其到各中心的距離,將樣本歸到距離最短的中心所在的類

(3)利用均值等方法更新中心值

(4)對於所有的聚類中心,如果利用(2)(3)步驟迭代法更新後,值不變,則迭代結束

對於以上文字來說可能不太容易理解,接下來博主會放圖結合文字來說明K-means

簡單的準備幾個點作爲數據演示:x1(1,1),x2(3,2),x3(1,3),x4(4,7),x5(6,7),x6(5,10),x7(10,4),x8(11,6),x9(12,4)

將這些點繪製出來(左上角爲原點):

將這些點用矩陣表示,每一列表示一個點,第0行表示X值,第一行表示Y值,第二行表示數據所屬類別:

執行算法描述的第一步:

適當選擇c個類的初始中心,我們打算將數據分爲三個類,則需要任意選取三個點作爲聚類中心,我們取(1,1),(3,2),(1,3)作爲初始聚類中心

執行算法描述第二步:

在第k個迭代中,對任意樣本,求其到各中心的距離,將樣本歸到距離最短的中心所在的類,即需要求出事先準備好的9個點到每個聚類中心點的距離,我們構建D矩陣用於存儲距離

D[i][j]表示第(i+1)個聚類點到第(j+1)個數據點的距離,例如:D[1][4] = 5.831 表示第2個聚類點(3,2)到第5個數據點(6,7)的距離

D矩陣每一行表示所有點到同一聚類中心的距離,每一列表示同一點到不同聚類中心的距離,我們將每一列進行比較,得到最小的距離,即可判斷該點屬於哪一類。我們構建G矩陣,將D矩陣按列進行比較,找到最小值位置,對應G矩陣位置 置爲1,例如:D矩陣第一列0,2.2361,2,最小值爲0,對應下標爲D[0][0],則G[0][0] = 1, D矩陣第四列,6.7082,5.099,5,最小值爲5,對應下標D[2][3],則G[2][3] = 1,得到G矩陣,

每一列表示一個數據,第幾行爲1表示該數據屬於哪一類,例如x1的第一行爲1,表示x1當前屬於第一類,x2,x5,x7,x8,x9的第二行爲1,則這些數據當前屬於第二類,x3,x4,x6的第三行爲1,則這些數據屬於第三類。

執行算法描述第三步:

利用均值方法更新中心值,我們將同一類的數據取均值代替原來的聚類點。當前第一類只有一個數據點(1,1),所以聚類點1爲(1/1,1/1)=(1,1),第二類數據有x2,x5,x7,x8,x9,則新的聚類點爲((3+6+10+11+12)/5,2+7+4+6+4)/5)=(8.4,4.6),同理,新的第三個聚類點爲((1+4+5)/3, (3+7+10)/3)=(3.333,6.666)

執行算法描述第四步:

對於所有的聚類中心,如果利用(2)(3)步驟迭代法更新後,值不變,則迭代結束

由於新的聚類點改變了,需要繼續迭代

繼續構建距離矩陣D,將所有點對新的聚類中心求距離

繼續構建G矩陣,找到數據點所屬類別

這次x1,x2,x3屬於第一類,x4,x5,x6屬於第三類,x7,x8,x9屬於第二類,利用均值方法更新中心值,新的聚類點爲(1.666 ,2.000)、 ( 11.000,  4.666)、(  5.000,  8.000),發現聚類點再次變化,繼續迭代

繼續構建距離矩陣D,將所有點對新的聚類中心求距離

繼續構建G矩陣,找到數據點所屬類別

新的聚類點爲(1.666 ,2.000)、 ( 11.000,  4.666)、(  5.000,  8.000),發現聚類點停止改變,停止迭代。

所以我們將x1,x2,x3分爲一類,x4,x5,x6分爲一類,x7,x8,x9分爲一類,將其繪製出來:

以上就是本人對K-means的理解

接下來爲C++、opencv對K-means測試的代碼

#include<iostream>
#include<opencv2\opencv.hpp>
#include<math.h>

using namespace std;
using namespace cv;
#define K 3
//找到兩點之間距離
float getDistance(Point2f A, Point2f B)
{
	float distance = 0.0;
	distance = sqrt(pow(A.x - B.x, 2) + pow(A.y - B.y,2));
	return distance;
}
//找出vector中的最小值
int getMinIndex(vector<float>data)
{
	float index = 0;
	float min = 10000;
	for (int i = 0; i < K; i++)
	{
		if (data[i] < min)
		{
			min = data[i];
			index = i;
		}
	}
	return index;
}
//找出vector中值爲1的下標
vector<int> getIndexIsOne(vector<float>data)
{
	vector<int>index;
	for (int i = 0; i < data.size(); i++)
	{
		if (data[i] == 1)
			index.push_back(i);
	}
	return index;
}
//是否停止迭代
bool shouldStop(vector<Point2f>oldCentroids, vector<Point2f>centroids, int iterations, int maxIt)
{
	if (iterations > maxIt)
		return true;
	return oldCentroids == centroids;

}
//更新數據類別
void updateLabels(Mat &dataset, vector<Point2f>points, vector<Point2f>&centroids)
{
	//構建D0矩陣  K行N列,用與記錄每個點與聚類點的距離
	Mat D0 = Mat::zeros(Size(points.size(), K), CV_32F);
	//計算每個點與聚類點之間的距離,D0[i][j]表示第i個數據點與第j個聚類點的距離
	for (int i = 0; i < K; i++)
	{
		for (int j = 0; j < points.size(); j++)
		{
			D0.at<float>(i, j) = getDistance(points[j], centroids[i]);
		}
	}
	//構建G矩陣  K行N列,按列進行比較,找到最小值位置,對應G矩陣位置 置爲1
	Mat G0 = Mat::zeros(Size(points.size(), K), CV_32F);
	for (int i = 0; i < points.size(); i++)
	{
		Mat col;
		//獲取每一列數據後,使用reshape轉換成vector便於計算
		D0.colRange(i, i + 1).copyTo(col);
		//reshape(cn,row) 
		vector<float>colsVec(col.reshape(1, 1));
		G0.at<float>(getMinIndex(colsVec), i) = 1;
	}
	for (int i = 0; i < K; i++)
	{
		Mat row;
		//獲取每一行數據後,使用reshape轉換成vector便於計算
		G0.rowRange(i, i + 1).copyTo(row);
		vector<float>rowsVec(row.reshape(1, 1));
		vector<int>indexVec;
		indexVec = getIndexIsOne(rowsVec);
		int xSum = 0.0;
		int ySum = 0.0;
		//利用均值更新聚類點
		for (int j = 0; j < indexVec.size(); j++)
		{
			dataset.at<float>(2, indexVec[j]) = i*1.0;//bug
			xSum += points[indexVec[j]].x;
			ySum += points[indexVec[j]].y;
		}
		centroids[i].x = xSum*1.0 / indexVec.size();
		centroids[i].y = ySum*1.0 / indexVec.size();
	}
}
Mat Kmeans(vector<Point2f>points, int classification, int maxIt)
{
	//創建一個 3行,N列的數據,多出來的一行用於表示數據類別
	Mat dataset = Mat::zeros(Size(points.size(), 3), CV_32FC1);
	for (int i = 0; i < points.size(); i++)
	{
		dataset.at<float>(0, i) = points[i].x;
		dataset.at<float>(1, i) = points[i].y;
	}
	vector<Point2f>centroids(3);
	//初始化聚類點,任意取K個數據作爲初始數據
	for (int i = 0; i < K; i++)
	{
		centroids[i] = points[i];
	}
	int iterations = 0;
	//用於比較聚類點是否發生變化
	vector<Point2f>oldCentroids(3,Point(0,0));

	while (!shouldStop(oldCentroids, centroids, iterations, maxIt))
	{
		iterations++;
		oldCentroids.assign(centroids.begin(), centroids.end());

		updateLabels(dataset, points, centroids);
	}
	return dataset;
}
//繪圖
void DrawMat(Mat dataset,Mat &drawingBoard)
{
	
	for (int i = 0; i < dataset.cols; i++)
	{
		if (dataset.at<float>(2, i) == 0)
			//circle(drawingBoard,Point(dataset.at<float>(1, i),)
			drawingBoard.at<Vec3b>(dataset.at<float>(0, i), dataset.at<float>(1, i)) = Vec3b(0, 0, 255);
		else if (dataset.at<float>(2, i) == 1)
			drawingBoard.at<Vec3b>(dataset.at<float>(0, i), dataset.at<float>(1, i)) = Vec3b(0, 255, 0);
		else if (dataset.at<float>(2, i) == 2)
			drawingBoard.at<Vec3b>(dataset.at<float>(0, i), dataset.at<float>(1, i)) = Vec3b(255, 0, 0);
	}
	imshow("散點分類圖",drawingBoard);
}
int main()
{

	Point2f x1 = Point2f(1, 1);
	Point2f x2 = Point2f(3, 2);
	Point2f x3 = Point2f(1, 3);
	Point2f x4 = Point2f(4, 7);
	Point2f x5 = Point2f(6, 7);
	Point2f x6 = Point2f(5, 10);
	Point2f x7 = Point2f(10, 4);
	Point2f x8 = Point2f(11, 6);
	Point2f x9 = Point2f(12, 4);
	//將數據放入容器中,便於計算
	vector<Point2f>points;
	points.push_back(x1);
	points.push_back(x2);
	points.push_back(x3);
	points.push_back(x4);
	points.push_back(x5);
	points.push_back(x6);
	points.push_back(x7);
	points.push_back(x8);
	points.push_back(x9);


	Mat dataset1 = Mat::zeros(20, 20, CV_8UC3);
	for (int i = 0; i < points.size(); i++)
	{
		dataset1.at<Vec3b>(points[i]) = Vec3b(255, 255, 255);
	}

	Mat dataset = Kmeans(points, K, 200);
	Mat drawingBoard = Mat::zeros(20, 20, CV_8UC3);
	DrawMat(dataset, drawingBoard);
	waitKey(0);
	return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章