Spark2.1.0_ml 決策樹分類模型

目錄

1.導入包

2.導入數據並創建DataFrame

3.劃分數據集,定義模型框架

4.用pipline將訓練步驟串聯,訓練模型

5.在測試集上預測,查看部分結果

6.評估模型,打印樹模型

7.運行結果


1.導入包

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

2.導入數據並創建DataFrame

object DecisionTreeClassificationExampleMl {
  case class Iris(features: Vector, label: String)  //注意:需寫在main外面
  def main(args: Array[String]): Unit = {
    //!!!注意:如果在Windows上執行,指定Hadoop的Home
    System.setProperty("hadoop.home.dir", "D:\\temp\\hadoop-2.4.1\\hadoop-2.4.1")
    //不打印日誌
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    // 創建sparksession對象
    val spark = SparkSession.builder()
      .master("local")
      .appName("DTML")
      .getOrCreate()

    // 使用case class 創建DataFrame
    import spark.implicits._

    val data = spark.sparkContext.textFile("D:\\temp\\iris.txt")
      .map(_.split(","))
      .map(p => Iris(Vectors.dense(p(0).toDouble,p(1).toDouble,p(2).toDouble, p(3).toDouble),p(4))).toDF()

    // 需要生成視圖,纔可執行SQL語句
    data.createOrReplaceTempView("iris")
    val df = spark.sql("select * from iris")

3.劃分數據集,定義模型框架

注意:劃分數據集的時候如果想每次劃分不一樣,則不指定seed參數。

    //我們把數據集隨機分成訓練集和測試集,其中訓練集佔70%。
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed=0)
    //分別獲取標籤列和特徵列,進行索引。
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(df)
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(df)
    //將預測的類別轉回字符型。
     val labelConverter = new IndexToString()
       .setInputCol("prediction")
       .setOutputCol("predictedLabel")
       .setLabels(labelIndexer.labels)
    //定義決策樹模型。
    val dtClassifier = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")

4.用pipline將訓練步驟串聯,訓練模型

    //在pipeline中進行設置
    val pipelinedClassifier = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))
    //訓練決策樹模型
    val modelClassifier = pipelinedClassifier.fit(trainingData)

5.在測試集上預測,查看部分結果

    //進行預測
    val predictionsClassifier = modelClassifier.transform(testData)
    //查看部分預測的結果
    predictionsClassifier.select("predictedLabel", "label", "features").show(20)

6.評估模型,打印樹模型

    /*
    * 評估模型的兩種寫法,一種自己算,另一種調用MulticlassClassificattionEvaluator,推薦第二種
    * */
    // Evaluate model on test instances and compute test error
    val testErr = predictionsClassifier.filter($"predictedLabel" !== $"label").count().toDouble
    val all = testData.count().toDouble
    println("Total data is " + all + ", wrong data is "+ testErr + ". Test Error = " + testErr/all)

    // 推薦下面這種方法
    // Select (prediction, true label) and compute test error.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictionsClassifier)
    println("Test Error = " + (1.0 - accuracy))
    
    // print DecisionTree classification model
    val dtModel = modelClassifier.stages(2).asInstanceOf[DecisionTreeClassificationModel]
    println("Learned classification tree model:\n" + dtModel.toDebugString)
    
    spark.stop()
  }
}

7.運行結果

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章