書接上回
在上篇中,我們介紹了XGBoost的原生接口使用方法,以及sklearn版本的接口。本篇我們再結合Scala/Spark來聊聊,以體現XGBoost在工程上的易用性。
Spark是基於Scala原生語言開發的一個分佈式迭代計算平臺,其中MLLib模塊包括了很多機器學習算法包(但比起Sklearn來肯定還是少的)。
Scala 是一門面向對象+函數式JVM語言,需要編譯後才能執行,但它提供了像Python那樣的交互式編程方式,調試代碼非常方便。推薦有好奇心的同學去了解下。
Scala + XGBoost
官方DEMO
官方網站上的一段DEMO程序如下:
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.XGBoost
object XGBoostScalaExample {
def main(args: Array[String]) {
// read trainining data, available at xgboost/demo/data
val trainData =
new DMatrix("/path/to/agaricus.txt.train")
// define parameters
val paramMap = List(
"eta" -> 0.1,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(trainData, paramMap, round)
// run prediction
val predTrain = model.predict(trainData)
// save model to the file.
model.saveModel("/local/path/to/model")
}
下面我們來解剖下這段小程序,以做到以點帶面地瞭解其使用細節。
環境配置
既然Scala是一門編譯型語言,我們先需要搞清楚怎麼在裏面去使用到XBGoost。對於Java而言,我們可以手動下載軟件的JAR包後導入工程,也可以通過配置Maven依賴來自動下載依賴的JAR包。而Scala天生地依賴於JVM生態,幾乎所有java類JAR包都適用於它。
在[https://github.com/dmlc/xgboost/tree/master/jvm-packages] 這個頁面上我們可以找到適用於JVM的Maven依賴:
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>latest_version_num</version>
</dependency>
官網上說最新的版本號可以從https://github.com/dmlc/xgboost/releases上面去找,這裏我們看到穩定版本已經有1.0.0了。
但事實上,如果直接把上面依賴項中的latest_version_num改爲1.0.0,maven自動是查找不到的。所以我們人肉去maven repository 網站上(https://mvnrepository.com/)找一下更靠譜,會發現這裏其實更新到0.90版本。
點擊進去後,獲得如下的依賴項:
<!-- https://mvnrepository.com/artifact/ml.dmlc/xgboost4j -->
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.90</version>
</dependency>
這下子再來看Maven工程就可以順利導入了,在筆者使用的IDEA編輯器界面上,自動導入了以下package:
準備工作完成後,我們就開始實際編碼體驗在Scala環境下的使用方法了。上面官方給出的DEMO有點過於簡單,我們會稍加展開說明。
上手幹
訓練數據我們仍像上篇那樣採用官方給出的預處理好的數據集,可從https://github.com/dmlc/xgboost/tree/master/demo/data找到agaricus.txt.train、agaricus.txt.test兩個數據集。https://xgboost.readthedocs.io/en/latest/tutorials/input_format.html這個頁面上有提到它使用的是LibSVM格式數據,而這種格式數據的標準規範就是:
<label> <index1>:<value1> <index2>:<value2>
這一點是我們在上篇中未曾提及的,這裏做一個補充說明,希望大家瞭解這是一種業界也較爲常用的數據交換格式(現實世界中只要有一個標準,遵守的人越多,數據交換就越方便)。
先貼出筆者自己的完整代碼:
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}
object tryXGBoost {
def main(args:Array[String])= {
// 首先肯定還是想着怎麼讀入數據
val trainData = new DMatrix("~/Downloads/agaricus.txt.train")
val testData = new DMatrix("~/Downloads/agaricus.txt.test")
// 配置參數
val paramMap = List(
"eta" -> 0.1,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
// number of iterations
val round = 2
// 創建模型實例,並擬合數據
val model = XGBoost.train(trainData,paramMap,round)
// 預測數據
val predicts: Array[Array[Float]] = model.predict(testData)
val posPred = predicts.map(x => x(0))
predicts.take(10).foreach(println(_))
println("------------------------")
posPred.take(10).foreach(println(_))
// 計算準確率
// 以0.5爲閾值,如果預測的概率值大於0.5,則認爲正例
val labelPredicted: Array[Int] = predicts.map(x => if(x(0) > 0.5) 1 else 0)
// labelPredicted.take(10).foreach(println(_))
val trueLabel = testData.getLabel.map(_.toInt)
// trueLabel.take(10).foreach(println(_))
var correctNum = 0
for( i <- 0 to labelPredicted.length - 1){
if(labelPredicted(i) == trueLabel(i)){
correctNum += 1
}
}
val correctRatio = 1.0*correctNum/labelPredicted.length
println("準確率:" + correctRatio)
}
}
這裏有幾個地方解釋一下,
val predicts: Array[Array[Float]] = model.predict(testData)
這一句執行完成後,我們打印一下該變量,會發現它長成這樣:
可見這不是一個常見的整型或字符串。利用IDEA這個編輯器的功能,自動添加變量的類型(即: Array[Array[Float]]),可發現model.predict返回的每一個值都是一個向量,進一步測試可發現長度爲1,因爲我們通過
val posPred = predicts.map(x => x(0))
獲取到向量裏面的元素,再來打印posPred:
這樣就正常了嘛!只不過它是一個概率值,我們設定一個閾值後,即可得到分類爲1。
這樣一段小程序最終的準確率:0.9596523,可見即使是未經仔細調參的模型,準確度也還是不錯的。
https://github.com/dmlc/xgboost/tree/master/jvm-packages/xgboost4j-example這裏面提供了更多使用Scala API的例子,感興趣的同學可以自行去翻看。
Spark + XGBoost
上面我們是用Scala做的單機本地環境下的編程體驗,而Spark是基於分佈式集羣的,未來在實際工作中可能更多會是這樣的使用方式。
依賴項
XGBoost爲Spark專門提供了可依賴的Maven:
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark</artifactId>
<version>latest_version_num</version>
</dependency>
但因爲是Spark版本,還需要配置對於spark的依賴:
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>2.4.4</version>
</dependency>
以下是筆者實際跑的完整代碼,用的數據集是Iris數據集(可從http://archive.ics.uci.edu/ml/datasets/Iris上去下載)。
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
object tryXGBoostSpark {
def main(args:Array[String]): Unit ={
val spark = SparkSession.builder.master("local").appName("tryXGB").getOrCreate()
val schema = new StructType(Array(
StructField("SepalLengthCm", DoubleType, true),
StructField("SepalWidthCm", DoubleType, true),
StructField("PetalLengthCm", DoubleType, true),
StructField("PetalWidthCm", DoubleType, true),
StructField("Species", StringType, true)))
// 以適合於數據的格式來讀入
val rawInput: DataFrame = spark.read.schema(schema).csv("~/Downloads/Iris.csv")
// 將原字符型變量轉變爲序列值0,1,2這樣
val indexer = new StringIndexer()
.setInputCol("Species")
.setOutputCol("label")
//擬合數據來訓練
val labelTransformed = indexer.fit(rawInput).transform(rawInput).drop("Species")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("SepalLengthCm","SepalWidthCm","PetalLengthCm","PetalWidthCm"))
.setOutputCol("fetures")
val xgbInput = vectorAssembler.transform(labelTransformed).select("fetures", "label")
val paramMap = Map(
"eta" -> 0.1f ,
"num_class" -> 3,
"max_depth" -> 2 ,
"objective" -> "multi:softprob" ,
"num_round" -> 50,
"num_workers" -> 1
)
val xgbClassifier = new XGBoostClassifier(paramMap).setFeaturesCol("fetures").setLabelCol("label")
val xgbClassificationModel = xgbClassifier.fit(xgbInput)
val predict = xgbClassificationModel.transform(xgbInput)
predict.show(20)
predict.select("label","prediction").distinct().show()
val correct = predict.select("label","prediction").where("label == prediction")
print("準確率: " + 1.0*correct.count() / predict.count())
spark.stop()
}
}
執行完成後輸出結果爲0.986666。
spark相關的基礎知識比較多,這裏暫不深入展開,感興趣的讀者可以在微信公衆號上留言交流,也可以加入我們的QQ羣交流,羣號1448524707。
結語
本篇我們繼續介紹XGBoost在工程上的實際使用,分別編寫了Scala、Spark的版本代碼,爲的是讓大家有一個上手練習的參考。這裏未對參數展開討論,且背後的算法原理還是需要繼續鑽研。
在下篇中,我們會一起讀一下XGBoost的paper原文,希望更進一步瞭解算法本身的原理。敬請繼續關注。
歡迎關注本人微信公衆號: