Spark Mllib 迴歸學習筆記三(java):決策樹

決策樹

決策樹模型,適用於分類、迴歸。
簡單地理解決策樹呢,就是通過不斷地設置新的條件標準對當前的數據進行劃分,最後以實現把原始的雜亂的所有數據分類。

就像下面這個圖,如果輸入是一大堆追求一個妹子的漢子,妹子內心裏有個篩子,最後菇涼也就決定了和誰約(舉慄而已哦,不代表什麼~大家理解原理重要~~)

不難看出,構造決策樹的關鍵就在於劃分條件終止條件的決定

  • 一個屬性能不能作爲劃分條件要看用他來分類好不好,我們說原始信息是無序的,那麼他能不能很好地降低信息的無序性。

    我們常用Gini不純度錯誤率(Error)熵(Entropy)來衡量信息的混亂程度,公式定義分別如下:


    P(i)表示事件i發生的概率,這三個數越大說明數據越不純。
    比較屬性的劃分效果的算法有C4.5、ID3。詳細的可以參考這篇博文

  • 在spark中終止條件可以由決策樹的構造方法DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, maxDepth, maxBins)
    的參數:最大深度maxDepth、最大劃分數(在構建節點時把數據分到多少個盒子中去)maxBins來決定
    參數categoricalFeaturesInfo是一個映射表,用來指明哪些特徵是分類的,以及他們有多少個類。比如,特徵1是一個標籤爲1,0的二元特徵,特徵2是0,1,2的三元特徵,則傳遞{1: 2, 2: 3}。如果沒有特徵是分類的,數據是連續變量,那麼我們可以傳遞空表。
    impurity表示結點的不純淨度測量,分類問題採用 gini或者entropy,而回歸必須用 variance。
     
    決策樹的缺點是容易過擬合,導致訓練出來的模型對訓練集的擬合效果很好,對其他數據的效果卻有所下降。對深度和最大劃分數的設定就是爲了避免這種情況,當然,在下面我們還將接觸到決策樹的優化版:隨機森林,隨機森林就可以很好地處理這個問題。

實例

操作數據

package linear;

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
public class DecisionTreeRegression {

    /**
     * @param args
     */
    public static void main(String[] args) {
        // TODO Auto-generated method stub
        SparkConf sparkConf = new SparkConf().setAppName("DecisionTreeRegression").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        //一、加載文件。libsvm文件格式形如 Label   1:value  2:value....
        String datapath = "/home/monkeys/sample_libsvm_data.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
        //把70%的數據用做訓練集,剩下的爲測試集
        JavaRDD<LabeledPoint> [] splits = data.randomSplit(new double[]{0.7, 0.3});
        JavaRDD<LabeledPoint> trainingData = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];

        //二、設置參數:這裏用hashmap表徵連續變量
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        String impurity = "variance";
        Integer maxDepth = 5;//最大深度
        Integer maxBins = 32;//最大劃分數

        //三、訓練模型:
        final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
        JavaPairRDD<Double, Double> predictionAndLabel = testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>(){
            public Tuple2<Double, Double> call(LabeledPoint p){
                return new Tuple2<Double,  Double>(model.predict(p.features()), p.label());
            }
        }
                );

        //四、計算誤差:平方和的均值
        Double testMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>(){
            //@Override
            public Double call(Tuple2<Double, Double> p1){
                Double diff = p1._1() - p1._2();
                return diff * diff;
            }
        }
        ).reduce(new Function2<Double, Double, Double>(){
                public Double call(Double a, Double b){
                    return a + b;
                }
        }) / data.count();

        System.out.println("Test Mean squared error: " + testMSE);
        System.out.println("Learned regression tree model: \n" + model.toDebugString());

        //model.save(jsc.sc(), "myDecisionTreeRegressionModel");
        //DecisionTreeModel sameModel = DecisionTreeModel.load(jsc.sc(), "myDecisionTreeRegressionModel");
        }
    }
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章