Spark MLlib中KMeans聚類算法的使用

本文首發於我的個人博客QIMING.INFO,轉載請帶上鍊接及署名。

KMeans是一種典型的聚類算法,本文通過代碼來演示用spark運行KMeans算法的一個小例子。

算法簡介

KMeans算法的基本思想是初始隨機給定K個簇中心,按照最鄰近原則把無標籤樣本點分到各個簇。然後按平均法重新計算各個簇的質心,從而確定新的簇心。一直迭代,直到簇心的移動距離小於某個給定的值或迭代次數達到閾值。

運行步驟

數據說明

數據格式爲:特徵1 特徵2 特徵3

0.0 0.0 0.0
0.1 0.1 0.1
0.2 0.2 0.2
9.0 9.0 9.0
9.1 9.1 9.1
9.2 9.2 9.2

代碼及說明

import org.apache.log4j.{ Level, Logger }
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SparkSession

object KMeansTest {
  def main(args: Array[String]): Unit = {
    // 生成SparkSession對象
    val spark = new SparkSession.Builder().appName("KMeansTest").getOrCreate()
    //生成SparkContext對象
    val sc = spark.sparkContext
    //設置日誌輸出級別
    Logger.getRootLogger.setLevel(Level.WARN)


    // 裝載數據集
    val data = sc.textFile("/home/hadoop/ML_Data/input/kmeans_data.txt")
    val parsedData = data.map(s => Vectors.dense(s.split("\\s+").map(_.toDouble))).cache()

    // 將數據集聚類,4個類,50次迭代,進行模型訓練形成數據模型
    val numClusters = 4
    val numIterations = 50
    val runs = 20        // 執行20次取最優
    val model = KMeans.train(parsedData, numClusters, numIterations, runs)

    //打印輸出中心點座標
    val centers = model.clusterCenters
    println("centers")
    for (i <- 0 to centers.length - 1) {
      println(centers(i)(0) + "\t" + centers(i)(1))
    }

    // 誤差計算
    val WSSSE = model.computeCost(parsedData)
    println("Within Set Sum of Squared Errors = " + WSSSE)

    //給每個數據點進行標號,zippedData格式爲(數據的類別號,數據)
    val c = parsedData.map(x => model.predict(x))
    c.persist()
    val zippedData = c.zip(parsedData)

    //保存zippedData到本地
    zippedData.saveAsTextFile("/home/hadoop/ML_Data/output/KMeans_data")
  }
}

結果展示



將結果合成到一個文件中,方便查看:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章