1 概述
傳統的決策樹算法包括ID3算法、C4.5算法以及CART算法。三者主要的區別在於特徵選擇準則不同。ID3算法選擇特徵的依據是信息增益,C4.5是信息增益比,而CART則是Gini指數。決策樹算法的基本流程如下圖所示。
2 ID3
2.1 理論
離散屬性的取值
:中在上取值=的樣本集合
以屬性對數據集進行劃分所獲得的信息增益爲:
其中,,值越小純度越高。
2.2 代碼
import numpy as np
import pandas as pd
from math import log
# 信息熵
def entropy(ele):
probs = [ele.count(i) / len(ele) for i in set(ele)]
entropy = -sum([prob * log(prob, 2) for prob in probs])
return entropy
# 數據劃分
def split_dataframe(data, col):
unique_values = data[col].unique()
result_dict = {elem: pd.DataFrame for elem in unique_values}
for key in result_dict.keys():
result_dict[key] = data[:][data[col] == key]
return result_dict
# 選擇最佳特徵:信息增益
def choose_best_col(data, label):
entropy_D = entropy(data[label].tolist())
cols = [col for col in data.columns if col not in [label]]
max_value, best_col = -999, None
max_splited = None
for col in cols:
splited_set = split_dataframe(data, col)
entropy_DA = 0
for subset_col, subset in splited_set.items():
entropy_Di = entropy(subset[label].tolist())
entropy_DA += len(subset) / len(data) * entropy_Di
info_gain = entropy_D - entropy_DA
if info_gain > max_value:
max_value, best_col = info_gain, col
max_splited = splited_set
return max_value, best_col, max_splited
# 創建ID3類
class ID3Tree:
class Node:
def __init__(self, name):
self.name = name
self.connections = {}
def connect(self, label, node):
self.connections[label] = node
def __init__(self, data, label):
self.columns = data.columns
self.data = data
self.label = label
self.root = self.Node("Root")
def print_tree(self, node, tabs):
print(tabs + node.name)
for connection, child_node in node.connections.items():
print(tabs + "\t" + "(" + connection + ")")
self.print_tree(child_node, tabs + "\t\t")
def construct_tree(self):
self.construct(self.root, "", self.data, self.columns)
def construct(self, parent_node, parent_connection_label, input_data, columns):
max_value, best_col, max_splited = choose_best_col(input_data[columns], self.label)
if not best_col:
node = self.Node(input_data[self.label].iloc[0])
parent_node.connect(parent_connection_label, node)
return
node = self.Node(best_col)
parent_node.connect(parent_connection_label, node)
new_columns = [col for col in columns if col != best_col]
for splited_value, splited_data in max_splited.items():
self.construct(node, splited_value, splited_data, new_columns)
if __name__ == '__main__':
df = pd.read_csv('../data.csv', dtype={'windy':'str'})
id3 = ID3Tree(df, 'play')
id3.construct_tree()
id3.print_tree(id3.root, '')
3 C4.5
3.1 理論
ID3的問題是對可取值數目較多的屬性有所偏好;故C4.5採用了信息增益率選擇特徵。計算公式如下:
其中,
4 CART
4.1 理論
CART算法包括特徵選擇、決策樹生成和決策樹剪枝三個部分, CART算法主要包括迴歸樹和分類樹兩種。迴歸樹特徵選擇準則用的是平方誤差最小準則,分類樹特徵選擇準則用的是基尼指數。此外,剪枝是決策樹算法的一種正則化手段。
-
迴歸樹
-
分類樹
基尼係數計算公式如下:
-
剪枝
剪枝就是通過主動去掉一些分支來降低過擬合風險,可分爲預剪枝與後剪枝。
- 預剪枝
在決策樹生成過程中,在劃分節點時,若該節點的劃分沒有提高其在驗證集上的準確率,則不進行劃分 - 後剪枝
後剪枝決策樹先生成一棵完整的決策樹,再從底往頂進行剪枝處理。
4.2 代碼
import numpy as np
import pandas as pd
# 基尼係數
def gini(nums):
probs = [nums.count(i)/len(nums) for i in set(nums)]
gini = sum([p*(1-p) for p in probs])
return gini
# 劃分數據
def split_dataframe(data, col):
unique_values = data[col].unique()
result_dict = {elem : pd.DataFrame for elem in unique_values}
for key in result_dict.keys():
result_dict[key] = data[:][data[col] == key]
return result_dict
# 選擇最佳特徵:基尼係數
def choose_best_col(df, label):
gini_D = gini(df[label].tolist())
cols = [col for col in df.columns if col not in [label]]
min_value, best_col = 999, None
min_splited = None
for col in cols:
splited_set = split_dataframe(df, col)
gini_DA = 0
for subset_col, subset in splited_set.items():
gini_Di = gini(subset[label].tolist())
gini_DA += len(subset)/len(df) * gini_Di
if gini_DA < min_value:
min_value, best_col = gini_DA, col
min_splited = splited_set
return min_value, best_col, min_splited
# 創建CART類
class CartTree:
class Node:
def __init__(self, name):
self.name = name
self.connections = {}
def connect(self, label, node):
self.connections[label] = node
def __init__(self, data, label):
self.columns = data.columns
self.data = data
self.label = label
self.root = self.Node("Root")
def print_tree(self, node, tabs):
print(tabs + node.name)
for connection, child_node in node.connections.items():
print(tabs + "\t" + "(" + connection + ")")
self.print_tree(child_node, tabs + "\t\t")
def construct_tree(self):
self.construct(self.root, "", self.data, self.columns)
def construct(self, parent_node, parent_connection_label, input_data, columns):
min_value, best_col, min_splited = choose_best_col(input_data[columns], self.label)
if not best_col:
node = self.Node(input_data[self.label].iloc[0])
parent_node.connect(parent_connection_label, node)
return
node = self.Node(best_col)
parent_node.connect(parent_connection_label, node)
new_columns = [col for col in columns if col != best_col]
for splited_value, splited_data in min_splited.items():
self.construct(node, splited_value, splited_data, new_columns)
if __name__ == '__main__':
df = pd.read_csv('../data.csv', dtype={'windy':'str'})
tree1 = CartTree(df, 'play')
tree1.construct_tree()
tree1.print_tree(tree1.root, "")
參考
理論:周志華《機器學習》,李航《統計學習方法》
代碼:https://github.com/luwill/machine-learning-code-writing