// Table of contents
// 1. Import packages
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// 2. Load the data and create a DataFrame
object DecisionTreeClassificationExampleMl {

  // One iris record: four numeric measurements packed into an ML Vector,
  // plus the species name as a string label.
  // NOTE: must be declared outside main() so that spark.implicits' Encoder
  // derivation can see it when calling .toDF().
  case class Iris(features: Vector, label: String)

  def main(args: Array[String]): Unit = {
    // When running on Windows, point Hadoop at a local installation
    // (needed for winutils.exe); harmless elsewhere.
    System.setProperty("hadoop.home.dir", "D:\\temp\\hadoop-2.4.1\\hadoop-2.4.1")

    // Silence framework logging so only the example's own output shows.
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    // Create the SparkSession in local mode.
    val spark = SparkSession.builder()
      .master("local")
      .appName("DTML")
      .getOrCreate()

    // ---- 2. Load the data and build a DataFrame via the Iris case class ----
    import spark.implicits._
    val data = spark.sparkContext.textFile("D:\\temp\\iris.txt")
      .map(_.split(","))
      .map(p => Iris(Vectors.dense(p(0).toDouble, p(1).toDouble, p(2).toDouble, p(3).toDouble), p(4)))
      .toDF()

    // Register a temp view so the data can also be queried with SQL.
    data.createOrReplaceTempView("iris")
    val df = spark.sql("select * from iris")

    // ---- 3. Split the data set and define the model stages ----
    // 70% training / 30% test. The fixed seed makes the split reproducible;
    // omit the seed argument to get a different split on each run.
    val Array(trainingData, testData) = df.randomSplit(Array(0.7, 0.3), seed = 0)

    // Index the string label column: "label" -> numeric "indexedLabel".
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(df)

    // Index the feature vector; columns with <= 4 distinct values are
    // treated as categorical, the rest as continuous.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(df)

    // Convert numeric predictions back to the original string labels.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // Define the decision-tree classifier on the indexed columns.
    val dtClassifier = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")

    // ---- 4. Chain the stages in a Pipeline and train the model ----
    val pipelinedClassifier = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))

    // Fit the whole pipeline on the training split.
    val modelClassifier = pipelinedClassifier.fit(trainingData)

    // ---- 5. Predict on the test split and inspect some results ----
    val predictionsClassifier = modelClassifier.transform(testData)
    predictionsClassifier.select("predictedLabel", "label", "features").show(20)

    // ---- 6. Evaluate the model and print the tree ----
    /*
     * Two ways to evaluate: count the errors by hand, or use
     * MulticlassClassificationEvaluator. The second is recommended.
     */
    // Manual evaluation: count mispredicted rows on the test set.
    // (=!= is the Column inequality operator; !== is deprecated.)
    val testErr = predictionsClassifier.filter($"predictedLabel" =!= $"label").count().toDouble
    val all = testData.count().toDouble
    println("Total data is " + all + ", wrong data is " + testErr + ". Test Error = " + testErr / all)

    // Recommended: select (prediction, true label) and compute accuracy.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictionsClassifier)
    println("Test Error = " + (1.0 - accuracy))

    // Print the learned decision tree (stage 2 of the fitted pipeline).
    val dtModel = modelClassifier.stages(2).asInstanceOf[DecisionTreeClassificationModel]
    println("Learned classification tree model:\n" + dtModel.toDebugString)

    spark.stop()
  }
}