分類的概念
分類的基本任務就是根據給定的一系列屬性集,最後去判別它屬於的類型!
比如我們現在需要去給動物分類,類別可選項爲哺乳類,爬行類,鳥類,魚類,或者兩棲類。給你一些屬性集如這個動物的體溫,是否胎生,是否爲水生動物,是否爲飛行動物,是否有腿,是否冬眠。
現在分類的基本任務就是,已知一個動物的屬性集,判斷或預測這個動物屬於哪一種類別?
決策樹分類法
簡述
從根節點開始,每個分支都會包含一個屬性測試條件,用於分開具有不同特性的記錄,最終到達葉節點,即可得到類標號。
具體過程
從根節點開始,從衆多的屬性集裏邊選擇一個屬性,由這個屬性把數據進行分類(該屬性的一個值則形成一個孩子節點),得到這個根節點的多個孩子節點。
再由這些孩子節點開始選擇剩餘的屬性來進行分類,遞歸的進行下去,直至所有屬性都已經使用完畢!
問題
(1). 如何確定選擇哪個屬性來作爲測試條件?
某個分類的熵值定義爲:
所以對於一個屬性來說,分類後的熵值越低說明數據的純度越高,這個正是我們想要得到的結果,故使用這個指標來判斷屬性的優先選擇權。
(2). 如何終止遞歸?避免過度擬合?
數據中可能會出現一些離羣點,這會造成決策樹在進行決策的過程中對這樣的數據非常敏感,所以我們可以使用一個閾值來終止遞歸(即當前的節點下數據標號的純度已經滿足某個閾值)。
關鍵代碼
private void buildDecisionTree(AttrNode node, String parentAttrValue, String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
node.setParentAttrValue(parentAttrValue);
String attrName = "";
double gainValue = 0;
double tempValue = 0;
if(remainAttr.size() == 1) {
System.out.println("attr null");
return ;
}
// 在所有剩餘屬性集裏選擇一個信息增益最大的屬性
for(int i = 0;i < remainAttr.size();i ++) {
if(isID3) {
// ID3算法計算信息增益
tempValue = computeGain(remainData, remainAttr.get(i));
} else {
// C4.5算法計算信息增益比
tempValue = computeGainRatio(remainData, remainAttr.get(i));
}
if(tempValue > gainValue) {
gainValue = tempValue;
// 找到最佳的屬性
attrName = remainAttr.get(i);
}
}
node.setAttrName(attrName);
// 得到這個屬性下的所有取值 去進一步拓展孩子節點
ArrayList<String> valueTypes = attrValue.get(attrName);
// 移除掉這個已經使用了的屬性
remainAttr.remove(attrName);
AttrNode[] childNode = new AttrNode[valueTypes.size()];
String[][] rData;
// 遍歷這個屬性的所有取值
for(int i = 0;i < valueTypes.size();i ++) {
// 把該種取值下的數據提取出來
rData = removeData(remainData, attrName, valueTypes.get(i));
childNode[i] = new AttrNode();
boolean sameClass = true;
ArrayList<String> indexArray = new ArrayList<>();
// 遍歷剩餘的數據
for(int k = 1;k < rData.length;k ++) {
indexArray.add(rData[k][0]);
if (!rData[k][attrNames.length - 1].equals(rData[1][attrNames.length - 1])) {
sameClass = false;
break;
}
}
if(!sameClass) {
buildDecisionTree(childNode[i], valueTypes.get(i), rData, remainAttr, isID3);
} else {
// 如果數據中標號全部相同(或者是達到了某個閾值)停止遞歸
childNode[i].setParentAttrValue(valueTypes.get(i));
childNode[i].setChildDataIndex(indexArray);
}
}
// 遞歸完成後,給頭結點設定孩子節點
node.setChildAttrNode(childNode);
}
總結
決策樹分類算法是屬於監督學習的算法,也就是他需要初始的數據來進行訓練,去得到一個經過訓練的模型。然後這個模型就可以用來根據屬性集預測標號。它的不足在於它無法進行增量計算,也就是當新增一些已知的數據集的時候,只有重新結合之前的數據來重新構建決策樹,而無法僅僅利用增量來構建強化。但是這類算法的思路非常簡單,理解起來也不難。
引申
CART算法(Classification And Regression Tree):也是一種決策樹分類算法,與之前的C4.5和ID3不同的是:
1. 每個非葉子節點都有兩個孩子節點,這也就意味着劃分條件僅爲等於和不等於某個值,來對數據進行劃分空間。
2. CART算法對於屬性的值採用的是基於Gini係數值的方式做比較,舉一個網上的一個例子:(劃分條件爲體溫是否恆溫)
比如體溫爲恆溫時包含哺乳類5個、鳥類2個,則:
Gini(left_child)=1−(57)2−(27)2=2049
體溫爲非恆溫時包含爬行類3個、魚類3個、兩棲類2個,則
Gini(right_child)=1−(38)2−(38)2−(28)2=4264
所以如果按照“體溫爲恆溫和非恆溫”進行劃分的話,我們得到GINI的增益(類比信息增益):
Gini(A)=715∗2049+815∗4264
最好的劃分就是使得GINI_Gain最小的劃分。
通過比較每個屬性的最小的gini指數值,作爲最後的結果。
3. CART算法在把數據進行分類之後,會對樹進行一個剪枝,常用的用前剪枝和後剪枝法,而常見的後剪枝發包括代價複雜度剪枝,悲觀誤差剪枝等等.代價複雜度剪枝的公式爲:
r=R(t)−R(Tt)NTt−1
其中R(t) 表示如果對節點進行剪枝的話,最終的誤差代價 = 該節點的誤差率 * 該節點數據數目所佔比例,R(Tt) 表示如果沒有進行剪枝的話,這顆子樹所有的葉子節點的誤差代價之和,NTt 表示該子樹葉子節點的個數。
scikit-learn使用
from sklearn import tree
# 有一些可選擇參數 可以查看文檔
clf = tree.DecisionTreeClassifier()
clf.fit(features_train, lables_train)
pre = clf.prediction(features_test)