決策樹
決策樹模型,適用於分類、迴歸。
簡單地理解決策樹呢,就是通過不斷地設置新的條件標準對當前的數據進行劃分,最後以實現把原始的雜亂的所有數據分類。
就像下面這個圖,如果輸入是一大堆追求一個妹子的漢子,妹子內心裏有個篩子,最後菇涼也就決定了和誰約(舉慄而已哦,不代表什麼~大家理解原理重要~~)
不難看出,構造決策樹的關鍵就在於劃分條件和終止條件的決定
一個屬性能不能作爲劃分條件要看用他來分類好不好,我們說原始信息是無序的,那麼他能不能很好地降低信息的無序性。
我們常用Gini不純度、錯誤率(Error)、熵(Entropy)來衡量信息的混亂程度,公式定義分別如下:
P(i)表示事件i發生的概率,這三個數越大說明數據越不純。
比較屬性的劃分效果的算法有C4.5、ID3。詳細的可以參考這篇博文在spark中終止條件可以由決策樹的構造方法
DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, maxDepth, maxBins)
的參數:最大深度maxDepth、最大劃分數(在構建節點時把數據分到多少個盒子中去)maxBins來決定
參數categoricalFeaturesInfo是一個映射表,用來指明哪些特徵是分類的,以及他們有多少個類。比如,特徵1是一個標籤爲1,0的二元特徵,特徵2是0,1,2的三元特徵,則傳遞{1: 2, 2: 3}。如果沒有特徵是分類的,數據是連續變量,那麼我們可以傳遞空表。
impurity表示結點的不純淨度測量,分類問題採用 gini或者entropy,而回歸必須用 variance。
決策樹的缺點是容易過擬合,導致訓練出來的模型對訓練集的擬合效果很好,對其他數據的效果卻有所下降。對深度和最大劃分數的設定就是爲了避免這種情況,當然,在下面我們還將接觸到決策樹的優化版:隨機森林,隨機森林就可以很好地處理這個問題。
實例
package linear;
import java.util.HashMap;
import java.util.Map;
import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
public class DecisionTreeRegression {
/**
* @param args
*/
public static void main(String[] args) {
// TODO Auto-generated method stub
SparkConf sparkConf = new SparkConf().setAppName("DecisionTreeRegression").setMaster("local[*]");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
//一、加載文件。libsvm文件格式形如 Label 1:value 2:value....
String datapath = "/home/monkeys/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
//把70%的數據用做訓練集,剩下的爲測試集
JavaRDD<LabeledPoint> [] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];
//二、設置參數:這裏用hashmap表徵連續變量
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "variance";
Integer maxDepth = 5;//最大深度
Integer maxBins = 32;//最大劃分數
//三、訓練模型:
final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
JavaPairRDD<Double, Double> predictionAndLabel = testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>(){
public Tuple2<Double, Double> call(LabeledPoint p){
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
}
}
);
//四、計算誤差:平方和的均值
Double testMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>(){
//@Override
public Double call(Tuple2<Double, Double> p1){
Double diff = p1._1() - p1._2();
return diff * diff;
}
}
).reduce(new Function2<Double, Double, Double>(){
public Double call(Double a, Double b){
return a + b;
}
}) / data.count();
System.out.println("Test Mean squared error: " + testMSE);
System.out.println("Learned regression tree model: \n" + model.toDebugString());
//model.save(jsc.sc(), "myDecisionTreeRegressionModel");
//DecisionTreeModel sameModel = DecisionTreeModel.load(jsc.sc(), "myDecisionTreeRegressionModel");
}
}