This article was first published on my personal blog, QIMING.INFO. Please include a link and attribution when reposting.
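MLPC (Multilayer Perceptron Classifier) is a classifier based on a feedforward artificial neural network (ANN). It is currently the only neural-network-related algorithm that Spark supports, and it lives in org.apache.spark.ml (not mllib). This article walks through a small code example of running MLPC on Spark.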
Algorithm Overview
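A multilayer perceptron is a feedforward neural network model with multiple layers.
In a feedforward network, each layer receives input only from the previous layer and passes its result on to the next layer. Nothing is fed back to an earlier layer, so the whole computation can be drawn as a directed acyclic graph. A network of this type is made up of three kinds of layers: an input layer, one or more hidden layers, and an output layer, as shown in the figure: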
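MLPC uses the BP (Back Propagation) algorithm. The goal of BP learning is to adjust the connection weights of the network so that, after adjustment, the network produces the desired output for any given input. The "back propagation" in the name refers to how the algorithm trains the network: it passes the error backward layer by layer and updates the connection weights between neurons one by one, until the error between the computed output and the expected output falls to the desired level.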
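To make the backward flow of error concrete, here is a minimal sketch in plain Scala of one backpropagation step on a toy 1-1-1 network (one input, one sigmoid hidden neuron, one sigmoid output neuron, no biases). The object name `BackpropSketch`, the squared-error loss, and the learning rate `lr` are assumptions chosen for illustration; this is not how Spark implements training internally.

```scala
object BackpropSketch {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Forward pass: x -> hidden activation h -> output y.
  def forward(x: Double, w1: Double, w2: Double): (Double, Double) = {
    val h = sigmoid(w1 * x)
    val y = sigmoid(w2 * h)
    (h, y)
  }

  // Squared-error loss (an assumption made for this sketch).
  def loss(x: Double, target: Double, w1: Double, w2: Double): Double = {
    val (_, y) = forward(x, w1, w2)
    0.5 * (y - target) * (y - target)
  }

  // One backpropagation step: compute the error term at the output,
  // pass it back to the hidden neuron, then update both weights.
  def step(x: Double, target: Double, w1: Double, w2: Double, lr: Double): (Double, Double) = {
    val (h, y) = forward(x, w1, w2)
    val deltaOut = (y - target) * y * (1.0 - y)      // dE/dz at the output neuron
    val deltaHidden = deltaOut * w2 * h * (1.0 - h)  // error passed back to the hidden neuron
    (w1 - lr * deltaHidden * x, w2 - lr * deltaOut * h)
  }
}
```

Calling `BackpropSketch.step(1.0, 1.0, 0.5, 0.5, 0.5)` returns a pair of updated weights whose loss on the same sample is lower than before the step.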
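In Spark's multilayer perceptron, the hidden-layer neurons use the sigmoid function as their activation function, and the output layer uses the softmax function.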
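The combination of sigmoid hidden layers and a softmax output layer can be sketched as a forward pass in plain Scala. The names `ForwardPassSketch`, `affine`, and `forward`, and the representation of a layer as a (weights, bias) pair, are assumptions for illustration and do not mirror Spark's internal classes.

```scala
object ForwardPassSketch {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Softmax with the maximum subtracted first for numerical stability.
  def softmax(zs: Array[Double]): Array[Double] = {
    val m = zs.max
    val exps = zs.map(z => math.exp(z - m))
    val total = exps.sum
    exps.map(_ / total)
  }

  // weights(j)(i) connects input i to neuron j; bias(j) belongs to neuron j.
  def affine(weights: Array[Array[Double]], bias: Array[Double], in: Array[Double]): Array[Double] =
    weights.zip(bias).map { case (row, b) =>
      row.zip(in).map { case (w, x) => w * x }.sum + b
    }

  // All layers except the last apply sigmoid; the last layer applies softmax.
  def forward(layers: Seq[(Array[Array[Double]], Array[Double])], input: Array[Double]): Array[Double] = {
    require(layers.nonEmpty, "at least an output layer is required")
    val hidden = layers.init.foldLeft(input) { case (act, (w, b)) =>
      affine(w, b, act).map(sigmoid)
    }
    val (wOut, bOut) = layers.last
    softmax(affine(wOut, bOut, hidden))
  }
}
```

The output is a probability vector with one entry per class, so its entries sum to 1 and the predicted class is the index of the largest entry.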
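Several important tunable parameters of MLPC:
- featuresCol: the name of the features column in the input DataFrame.
- labelCol: the name of the label column in the input DataFrame.
- layers: an integer array. The first element must equal the dimension of the feature vector, and the last element must equal the number of label classes in the training data (for a 2-class problem, write 2). The number of elements in between is the number of hidden layers, and each value is the number of neurons in that layer. For example, val layers = Array(5, 6, 5, 2) describes 5 input features, two hidden layers with 6 and 5 neurons, and 2 output classes.
- maxIter: the maximum number of iterations of the optimization algorithm. The default is 100.
- predictionCol: the name of the column that holds the prediction result.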
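The rules for the layers array can be checked before training. The sketch below is plain Scala with hypothetical helper names (`LayersCheck`, `validate`, `hiddenLayers`, `parameterCount`); none of them are Spark APIs. The parameter count assumes fully connected layers with one bias per neuron, which is an assumption of this sketch.

```scala
object LayersCheck {
  // Returns None when the layout is consistent, otherwise a description of the problem.
  def validate(layers: Array[Int], numFeatures: Int, numClasses: Int): Option[String] =
    if (layers.length < 2) Some("layers needs at least an input size and an output size")
    else if (layers.exists(_ <= 0)) Some("every layer size must be positive")
    else if (layers.head != numFeatures)
      Some(s"first element ${layers.head} does not match feature dimension $numFeatures")
    else if (layers.last != numClasses)
      Some(s"last element ${layers.last} does not match number of classes $numClasses")
    else None

  // Everything between the first and last element is a hidden layer.
  def hiddenLayers(layers: Array[Int]): Array[Int] = layers.drop(1).dropRight(1)

  // Trainable parameters, assuming dense layers with one bias per neuron.
  def parameterCount(layers: Array[Int]): Int =
    layers.zip(layers.drop(1)).map { case (in, out) => (in + 1) * out }.sum
}
```

For Array(5, 6, 5, 2) with 5 features and 2 classes, `validate` returns None, `hiddenLayers` gives the sizes 6 and 5, and `parameterCount` gives 83 under the stated assumption.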
Running Steps
Data Description
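MLPC is strict about its data source, which must be one of the following two:
- DataFrame
  When a DataFrame is used as the data source, you must specify its label column and features column.
- LIBSVM-format text file
  Each line has the form: label featureID:featureValue featureID:featureValue ...
This example uses a LIBSVM-format text file. The data looks like this: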
[xuqm@cu01 ML_Data]$ cat input/sample_multiclass_classification_data.txt
1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333
1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667
1 1:-0.722222 2:-0.166667 3:-0.864407 4:-0.833333
1 1:-0.722222 2:0.166667 3:-0.694915 4:-0.916667
0 1:0.166667 2:-0.416667 3:0.457627 4:0.5
……
……
……
2 1:-0.388889 2:-0.166667 3:0.186441 4:0.166667
0 1:-0.222222 2:-0.583333 3:0.355932 4:0.583333
1 1:-0.611111 2:-0.166667 3:-0.79661 4:-0.916667
1 1:-0.944444 2:-0.25 3:-0.864407 4:-0.916667
1 1:-0.388889 2:0.166667 3:-0.830508 4:-0.75
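To show how one LIBSVM line maps to a label and a feature vector, here is a small plain-Scala parser. `LibsvmLineSketch` and `parse` are hypothetical names, not Spark's reader. The sketch shifts the 1-based feature IDs in the file to 0-based indices, which matches the indices seen in the features column of the result table later in this article.

```scala
object LibsvmLineSketch {
  // Parses "label id:value id:value ..." into (label, zero-based index -> value).
  def parse(line: String): (Double, Map[Int, Double]) = {
    val tokens = line.trim.split("\\s+")
    val label = tokens.head.toDouble
    val features = tokens.tail.map { token =>
      token.split(":") match {
        case Array(id, value) => (id.toInt - 1) -> value.toDouble
        case _ => throw new IllegalArgumentException(s"malformed feature token: $token")
      }
    }.toMap
    (label, features)
  }
}
```

Parsing the first data line above, `LibsvmLineSketch.parse("1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333")`, yields label 1.0 and four features at indices 0 to 3. A feature ID that is absent from a line simply does not appear in the map, which is how sparse rows are represented.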
Code and Explanation
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.SparkSession

object MLPCTest {
  def main(args: Array[String]): Unit = {
    // Build the SparkSession
    val spark = SparkSession.builder.appName("MLPCTest").getOrCreate()

    // Read the data stored in LIBSVM format
    val data = spark.read.format("libsvm").load("file:///home/xuqm/ML_Data/input/sample_multiclass_classification_data.txt")

    // Split into a training set (60%) and a test set (40%)
    val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L)
    val train = splits(0)
    val test = splits(1)

    // Specify the layers of the neural network:
    // input layer with 4 nodes (4 features); two hidden layers with 5 and 4 nodes;
    // output layer with 3 nodes (3 classes)
    val layers = Array[Int](4, 5, 4, 3)

    // Create the MLPC trainer and set its parameters
    val trainer = new MultilayerPerceptronClassifier()
      .setLayers(layers)
      .setBlockSize(128)
      .setSeed(1234L)
      .setMaxIter(100)

    // Train the model
    val model = trainer.fit(train)

    // Use the trained model to predict on the test set
    val result = model.transform(test)
    val predictionAndLabels = result.select("prediction", "label")

    // Compute the accuracy on the test set and print it
    val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
    println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels))

    // Print the results
    result.show(60, false)

    // Release the Spark resources
    spark.stop()
  }
}
Results
// Compute the accuracy and print it
Test set accuracy = 0.9019607843137255
// Print the results
result.show(60,false)
+-----+---------------------------------------------------------+----------+
|label|features |prediction|
+-----+---------------------------------------------------------+----------+
|0.0 |(4,[0,1,2,3],[-0.666667,-0.583333,0.186441,0.333333]) |2.0 |
|0.0 |(4,[0,1,2,3],[-0.277778,-0.333333,0.322034,0.583333]) |0.0 |
|0.0 |(4,[0,1,2,3],[-0.222222,-0.583333,0.355932,0.583333]) |0.0 |
|0.0 |(4,[0,1,2,3],[-0.0555556,-0.833333,0.355932,0.166667]) |2.0 |
|0.0 |(4,[0,1,2,3],[-0.0555556,-0.166667,0.288136,0.416667]) |2.0 |
|0.0 |(4,[0,1,2,3],[-1.32455E-7,-0.166667,0.322034,0.416667]) |2.0 |
|0.0 |(4,[0,1,2,3],[0.111111,-0.583333,0.355932,0.5]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.222222,-0.166667,0.627119,0.75]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.333333,-0.583333,0.627119,0.416667]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.333333,-0.166667,0.423729,0.833333]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.388889,-0.166667,0.525424,0.666667]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.444444,-0.0833334,0.38983,0.833333]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.555555,-0.166667,0.661017,0.666667]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.722222,-0.333333,0.728813,0.5]) |0.0 |
|0.0 |(4,[0,1,2,3],[0.888889,-0.333333,0.932203,0.583333]) |0.0 |
|0.0 |(4,[0,1,2,3],[1.0,0.5,0.830508,0.583333]) |0.0 |
|0.0 |(4,[0,2,3],[0.166667,0.457627,0.833333]) |0.0 |
|0.0 |(4,[0,2,3],[0.388889,0.661017,0.833333]) |0.0 |
|1.0 |(4,[0,1,2,3],[-0.944444,-0.166667,-0.898305,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.722222,-0.166667,-0.864407,-0.833333]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.666667,-0.166667,-0.864407,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.666667,-0.0833334,-0.830508,-1.0]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.611111,0.166667,-0.79661,-0.75]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.555556,0.166667,-0.830508,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.555556,0.5,-0.830508,-0.833333]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.555556,0.5,-0.79661,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.5,0.166667,-0.864407,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.5,0.75,-0.830508,-1.0]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.388889,0.166667,-0.830508,-0.75]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.388889,0.166667,-0.762712,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.388889,0.583333,-0.898305,-0.75]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.388889,0.583333,-0.762712,-0.75]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.333333,0.25,-0.898305,-0.916667]) |1.0 |
|1.0 |(4,[0,1,2,3],[-0.166667,0.666667,-0.932203,-0.916667]) |1.0 |
|1.0 |(4,[0,2,3],[-0.833333,-0.864407,-0.916667]) |1.0 |
|1.0 |(4,[0,2,3],[-0.777778,-0.898305,-0.916667]) |1.0 |
|2.0 |(4,[0,1,2,3],[-0.611111,-1.0,-0.152542,-0.25]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.555556,-0.583333,-0.322034,-0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.388889,-0.166667,0.186441,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.333333,-0.666667,-0.0847458,-0.25]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.333333,-0.666667,-0.0508475,-0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.277778,-0.166667,0.186441,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.222222,-0.5,-0.152542,-0.25]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.222222,-0.333333,0.0508474,-4.03573E-8])|2.0 |
|2.0 |(4,[0,1,2,3],[-0.111111,-0.166667,0.0847457,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-0.0555556,-0.25,0.186441,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[-1.32455E-7,-0.25,0.254237,0.0833333]) |2.0 |
|2.0 |(4,[0,1,2,3],[0.0555554,-0.833333,0.186441,0.166667]) |2.0 |
|2.0 |(4,[0,1,2,3],[0.0555554,-0.25,0.118644,-4.03573E-8]) |2.0 |
|2.0 |(4,[0,1,2,3],[0.111111,0.0833333,0.254237,0.25]) |2.0 |
|2.0 |(4,[0,1,2,3],[0.333333,-0.166667,0.355932,0.333333]) |0.0 |
+-----+---------------------------------------------------------+----------+
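The reported accuracy can be cross-checked against the table by hand. The sketch below, in plain Scala, recomputes it from (prediction, label) pairs. The counts in `tablePairs` are tallied from the 51 rows shown above (46 rows where prediction equals label), and the names `AccuracySketch`, `accuracy`, and `confusion` are illustrative rather than part of Spark.

```scala
object AccuracySketch {
  // (prediction, label) pairs tallied from the result table above.
  val tablePairs: Seq[(Double, Double)] =
    Seq.fill(14)((0.0, 0.0)) ++ // label 0 predicted correctly
    Seq.fill(4)((2.0, 0.0)) ++  // label 0 predicted as 2
    Seq.fill(18)((1.0, 1.0)) ++ // label 1 predicted correctly
    Seq.fill(14)((2.0, 2.0)) ++ // label 2 predicted correctly
    Seq.fill(1)((0.0, 2.0))     // label 2 predicted as 0

  // Fraction of pairs where the prediction equals the label.
  def accuracy(pairs: Seq[(Double, Double)]): Double =
    if (pairs.isEmpty) 0.0
    else pairs.count { case (prediction, label) => prediction == label }.toDouble / pairs.size

  // Number of rows for each (prediction, label) combination.
  def confusion(pairs: Seq[(Double, Double)]): Map[(Double, Double), Int] =
    pairs.groupBy(identity).map { case (key, rows) => key -> rows.size }
}
```

`AccuracySketch.accuracy(AccuracySketch.tablePairs)` evaluates to 46/51, which agrees with the printed test set accuracy. The confusion counts also show where the errors are: all misclassifications are between classes 0 and 2, while class 1 is always predicted correctly.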