Spark K-Means 算法例子

k-means算法是以空間的點距離爲基準,隨機或者按照一定規則選擇幾個中心點數據,計算每個點到該幾個中心點的距離,按照距離值最近歸爲一類的原則,把空間所有的點歸爲初始化的幾個中心,稱之爲中心簇。
然後,找到每個中心簇的中心,再次計算空間所有的點到新的中心點的舉例並歸類,以此不斷迭代,直到達到迭代次數或者點中心不再變化爲止。

kmeans_data.txt中的數據

0.0 0.0 0.0
0.1 0.1 0.1
0.2 0.2 0.2
9.0 9.0 9.0
9.1 9.1 9.1
9.2 9.2 9.2

package spark;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

// $example on$
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;


public class JavaKMeansExample {

    public static void main(String[] args) {

        Logger logger = Logger.getLogger(JavaKMeansExample.class);
        // 設置日誌的等級 並關閉jetty容器的日誌
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
        Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF);

        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("JavaKMeansExample");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        String path = "F:/spark-2.1.0-bin-hadoop2.6/data/mllib/kmeans_data.txt";
        JavaRDD<String> data = jsc.textFile(path);
        JavaRDD<Vector> parsedData = data.map(s -> {
            String[] sarray = s.split(" ");
            double[] values = new double[sarray.length];
            for (int i = 0; i < sarray.length; i++) {
                values[i] = Double.parseDouble(sarray[i]);
            }
            return Vectors.dense(values);
        });
        parsedData.cache();

        int numClusters = 2;
        int numIterations = 20;
        int runs = 10;
        /**
         * KMeans.train(RDD<Vector> data, int k, int maxIterations, int runs, String initializationMode, long seed) data 進行聚類的數據 k
         * 初始的中心點個數 maxIterations 迭代次數 
         * runs 運行次數 
         * initializationMode 初始中心點的選擇方式, 目前支持隨機選 "random" or "k-means||"。默認是 K-means|| 
         * seed 集羣初始化時的隨機種子。
         */
        KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations, runs);
        // 輸出聚類的中心
        System.out.println("Cluster centers:");
        for (Vector center : clusters.clusterCenters()) {
            System.out.println(" " + center);
        }

        // 本次聚類操作的收斂性,此值越低越好
        double cost = clusters.computeCost(parsedData.rdd());
        System.out.println("Cost: " + cost);

        double WSSSE = clusters.computeCost(parsedData.rdd());
        System.out.println("Within Set Sum of Squared Errors = " + WSSSE);

        // 預測並輸出輸出每組數據對應的中心
        parsedData.foreach(f -> {
            System.out.print(f + "\n");
            System.out.println(clusters.predict(f));
        });
        // 預測數據屬於哪個中心點
        int centerIndex = clusters.predict(Vectors.dense(new double[] {3.6, 4.7, 7.1}));//中心點的索引
        System.out.println("預測數據 (3.6, 4.7, 7.1)屬於中心[" + centerIndex + "]:" + clusters.clusterCenters()[centerIndex]);

        centerIndex = clusters.predict(Vectors.dense(new double[] {1.1, 0.7, 0.3}));
        System.out.println("預測數據 (1.1,0.7, 0.3)屬於中心[" + centerIndex + "]:" + clusters.clusterCenters()[centerIndex]);
        jsc.stop();
    }
}

... 

參考

發佈了32 篇原創文章 · 獲贊 2 · 訪問量 2萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章