決策樹分類算法:ID3 & C4.5 & CART

分類的概念

分類的基本任務就是根據給定的一系列屬性集,最後去判別它屬於的類型!

比如我們現在需要去給動物分類,類別可選項爲哺乳類,爬行類,鳥類,魚類,或者兩棲類。給你一些屬性集如這個動物的體溫,是否胎生,是否爲水生動物,是否爲飛行動物,是否有腿,是否冬眠。

現在分類的基本任務就是,已知一個動物的屬性集,判斷或預測這個動物屬於哪一種類別?

決策樹分類法

簡述

從根節點開始,每個分支都會包含一個屬性測試條件,用於分開具有不同特性的記錄,最終到達葉節點,即可得到類標號。

具體過程

從根節點開始,從衆多的屬性集裏邊選擇一個屬性,由這個屬性把數據進行分類(該屬性的一個值則形成一個孩子節點),得到這個根節點的多個孩子節點。
再由這些孩子節點開始選擇剩餘的屬性來進行分類,遞歸的進行下去,直至所有屬性都已經使用完畢!

問題

(1). 如何確定選擇哪個屬性來作爲測試條件?
某個分類的熵值定義爲:
這裏寫圖片描述
所以對於一個屬性來說,分類後的熵值越低說明數據的純度越高,這個正是我們想要得到的結果,故使用這個指標來判斷屬性的優先選擇權。
(2). 如何終止遞歸?避免過度擬合?
數據中可能會出現一些離羣點,這會造成決策樹在進行決策的過程中對這樣的數據非常敏感,所以我們可以使用一個閾值來終止遞歸(即當前的節點下數據標號的純度已經滿足某個閾值)。

關鍵代碼

private void buildDecisionTree(AttrNode node, String parentAttrValue, String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
        node.setParentAttrValue(parentAttrValue);

        String attrName = "";
        double gainValue = 0;
        double tempValue = 0;

        if(remainAttr.size() == 1) {
            System.out.println("attr null");
            return ;
        }
        // 在所有剩餘屬性集裏選擇一個信息增益最大的屬性
        for(int i = 0;i < remainAttr.size();i ++) {
            if(isID3) {
                // ID3算法計算信息增益
                tempValue = computeGain(remainData, remainAttr.get(i));
            } else {
                // C4.5算法計算信息增益比
                tempValue = computeGainRatio(remainData, remainAttr.get(i));
            }

            if(tempValue > gainValue) {
                gainValue = tempValue;
                // 找到最佳的屬性
                attrName = remainAttr.get(i);
            }
        }

        node.setAttrName(attrName);
        // 得到這個屬性下的所有取值 去進一步拓展孩子節點
        ArrayList<String> valueTypes = attrValue.get(attrName);
        // 移除掉這個已經使用了的屬性
        remainAttr.remove(attrName);

        AttrNode[] childNode = new AttrNode[valueTypes.size()];
        String[][] rData;
        // 遍歷這個屬性的所有取值
        for(int i = 0;i < valueTypes.size();i ++) {
            // 把該種取值下的數據提取出來
            rData = removeData(remainData, attrName, valueTypes.get(i));

            childNode[i] = new AttrNode();
            boolean sameClass = true;
            ArrayList<String> indexArray = new ArrayList<>();
            // 遍歷剩餘的數據
            for(int k = 1;k < rData.length;k ++) {
                indexArray.add(rData[k][0]);
                if (!rData[k][attrNames.length - 1].equals(rData[1][attrNames.length - 1])) {
                    sameClass = false;
                    break;
                }
            }

            if(!sameClass) {
                buildDecisionTree(childNode[i], valueTypes.get(i), rData, remainAttr, isID3);
            } else {
                // 如果數據中標號全部相同(或者是達到了某個閾值)停止遞歸
                childNode[i].setParentAttrValue(valueTypes.get(i));
                childNode[i].setChildDataIndex(indexArray);
            }
        }
        // 遞歸完成後,給頭結點設定孩子節點
        node.setChildAttrNode(childNode);
    }

總結

決策樹分類算法是屬於監督學習的算法,也就是他需要初始的數據來進行訓練,去得到一個經過訓練的模型。然後這個模型就可以用來根據屬性集預測標號。它的不足在於它無法進行增量計算,也就是當新增一些已知的數據集的時候,只有重新結合之前的數據來重新構建決策樹,而無法僅僅利用增量來構建強化。但是這類算法的思路非常簡單,理解起來也不難。

引申

CART算法(Classification And Regression Tree):也是一種決策樹分類算法,與之前的C4.5和ID3不同的是:
1. 每個非葉子節點都有兩個孩子節點,這也就意味着劃分條件僅爲等於和不等於某個值,來對數據進行劃分空間。
2. CART算法對於屬性的值採用的是基於Gini係數值的方式做比較,舉一個網上的一個例子:(劃分條件爲體溫是否恆溫)
比如體溫爲恆溫時包含哺乳類5個、鳥類2個,則:

Gini(left_child)=1(57)2(27)2=2049

體溫爲非恆溫時包含爬行類3個、魚類3個、兩棲類2個,則
Gini(right_child)=1(38)2(38)2(28)2=4264

所以如果按照“體溫爲恆溫和非恆溫”進行劃分的話,我們得到GINI的增益(類比信息增益):
Gini(A)=7152049+8154264

最好的劃分就是使得GINI_Gain最小的劃分
通過比較每個屬性的最小的gini指數值,作爲最後的結果。
3. CART算法在把數據進行分類之後,會對樹進行一個剪枝,常用的用前剪枝和後剪枝法,而常見的後剪枝發包括代價複雜度剪枝,悲觀誤差剪枝等等.代價複雜度剪枝的公式爲:
r=R(t)R(Tt)NTt1

其中R(t) 表示如果對節點進行剪枝的話,最終的誤差代價 = 該節點的誤差率 * 該節點數據數目所佔比例,R(Tt) 表示如果沒有進行剪枝的話,這顆子樹所有的葉子節點的誤差代價之和,NTt 表示該子樹葉子節點的個數。

scikit-learn使用

from sklearn import tree
# 有一些可選擇參數 可以查看文檔
clf = tree.DecisionTreeClassifier()

clf.fit(features_train, lables_train)

pre = clf.prediction(features_test)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章