【ML From Beginner to Buried, Part 04】Decision Trees

1 Overview

The classical decision-tree algorithms are ID3, C4.5, and CART. They differ mainly in the feature-selection criterion: ID3 selects features by information gain, C4.5 by the information gain ratio, and CART by the Gini index. The basic decision-tree learning procedure is shown in the figure below.
(Figure: the basic decision-tree learning procedure)

2 ID3

2.1 Theory

Suppose a discrete attribute a has V possible values \left\{a^{1}, a^{2}, \ldots, a^{V}\right\}, and let D^{v} denote the subset of samples in D that take the value a^{v} on attribute a. The information gain obtained by splitting the dataset D on attribute a is:

\operatorname{Gain}(D, a)=\operatorname{Ent}(D)-\sum_{v=1}^{V} \frac{\left|D^{v}\right|}{|D|} \operatorname{Ent}\left(D^{v}\right)

where \operatorname{Ent}(D)=-\sum_{k=1}^{|\mathcal{Y}|} p_{k} \log _{2} p_{k} is the information entropy of D, with p_{k} the proportion of class k; the smaller the entropy, the higher the purity.
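
As a quick numeric check, for a hypothetical binary dataset of 14 samples with 9 positive and 5 negative:

\operatorname{Ent}(D)=-\frac{9}{14} \log _{2} \frac{9}{14}-\frac{5}{14} \log _{2} \frac{5}{14} \approx 0.940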

2.2 Code

import numpy as np
import pandas as pd
from math import log

# Information entropy of a list of labels
def entropy(ele):
    probs = [ele.count(i) / len(ele) for i in set(ele)]
    return -sum(prob * log(prob, 2) for prob in probs)

# Partition the data into sub-DataFrames, one per unique value of col
def split_dataframe(data, col):
    return {value: data[data[col] == value] for value in data[col].unique()}
    
# Choose the best feature by information gain
def choose_best_col(data, label):
    entropy_D = entropy(data[label].tolist())
    cols = [col for col in data.columns if col != label]
    max_value, best_col = float('-inf'), None
    max_splited = None
    for col in cols:
        splited_set = split_dataframe(data, col)
        entropy_DA = 0  # conditional entropy Ent(D|a), weighted over subsets
        for subset_col, subset in splited_set.items():
            entropy_Di = entropy(subset[label].tolist())
            entropy_DA += len(subset) / len(data) * entropy_Di
        info_gain = entropy_D - entropy_DA

        if info_gain > max_value:
            max_value, best_col = info_gain, col
            max_splited = splited_set
    return max_value, best_col, max_splited

# ID3 decision-tree class
class ID3Tree:

    class Node:
        def __init__(self, name):
            self.name = name
            self.connections = {}

        def connect(self, label, node):
            self.connections[label] = node

    def __init__(self, data, label):
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")

    def print_tree(self, node, tabs):
        print(tabs + node.name)
        for connection, child_node in node.connections.items():
            print(tabs + "\t" + "(" + connection + ")")
            self.print_tree(child_node, tabs + "\t\t")

    def construct_tree(self):
        self.construct(self.root, "", self.data, self.columns)

    def construct(self, parent_node, parent_connection_label, input_data, columns):
        # Stop if the node is pure: every remaining sample shares one label
        if len(input_data[self.label].unique()) == 1:
            node = self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label, node)
            return

        max_value, best_col, max_splited = choose_best_col(input_data[columns], self.label)

        # No feature left to split on: predict the majority label
        if not best_col:
            node = self.Node(input_data[self.label].value_counts().idxmax())
            parent_node.connect(parent_connection_label, node)
            return

        node = self.Node(best_col)
        parent_node.connect(parent_connection_label, node)

        new_columns = [col for col in columns if col != best_col]

        for splited_value, splited_data in max_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)

if __name__ == '__main__':
    df = pd.read_csv('../data.csv', dtype={'windy': 'str'})
    id3 = ID3Tree(df, 'play')
    id3.construct_tree()
    id3.print_tree(id3.root, '')
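
As a sanity check, the same data can be fit with scikit-learn's entropy-criterion tree (a minimal sketch, assuming the same hypothetical ../data.csv; note that scikit-learn grows binary splits over a one-hot encoding, so the printed tree will differ in shape from the multiway ID3 tree above):

import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

df = pd.read_csv('../data.csv', dtype={'windy': 'str'})
X = pd.get_dummies(df.drop(columns='play'))  # one-hot encode categorical features
clf = DecisionTreeClassifier(criterion='entropy', random_state=0).fit(X, df['play'])
print(export_text(clf, feature_names=list(X.columns)))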

3 C4.5

3.1 Theory

A weakness of ID3 is that information gain favors attributes with many possible values, so C4.5 selects features by the information gain ratio instead, computed as:

\text{Gain-ratio}(D, a)=\frac{\operatorname{Gain}(D, a)}{\operatorname{IV}(a)}

where \operatorname{IV}(a)=-\sum_{v=1}^{V} \frac{\left|D^{v}\right|}{|D|} \log _{2} \frac{\left|D^{v}\right|}{|D|} is the intrinsic value of attribute a, which grows with the number of distinct values a can take.
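
Since section 2.2 already has an entropy function and a splitter, the gain ratio is a small extension. Below is a minimal sketch (the function name gain_ratio and the use of pandas groupby are illustrative choices, not part of the original code):

from math import log

def entropy(labels):
    probs = [labels.count(c) / len(labels) for c in set(labels)]
    return -sum(p * log(p, 2) for p in probs)

def gain_ratio(data, col, label):
    ent_D = entropy(data[label].tolist())
    ent_DA, iv = 0.0, 0.0
    for _, subset in data.groupby(col):
        w = len(subset) / len(data)
        ent_DA += w * entropy(subset[label].tolist())  # conditional entropy Ent(D|a)
        iv -= w * log(w, 2)                            # intrinsic value IV(a)
    return (ent_D - ent_DA) / iv if iv > 0 else 0.0

Note that the gain ratio is itself biased toward attributes with few values, so C4.5 actually uses a heuristic: it first keeps the attributes whose information gain is above average, then picks the one with the highest gain ratio among them.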

4 CART

4.1 Theory

The CART algorithm has three parts: feature selection, tree generation, and pruning. CART trees come in two kinds: regression trees, which select splits by the minimum-squared-error criterion, and classification trees, which use the Gini index. Pruning serves as a form of regularization for decision trees.

  1. Regression tree
    A regression tree chooses the splitting variable j and split point s by the minimum-squared-error criterion (a sketch of this search follows the list):
    \min _{j, s}\left[\min _{c_{1}} \sum_{x_{i} \in R_{1}(j, s)}\left(y_{i}-c_{1}\right)^{2}+\min _{c_{2}} \sum_{x_{i} \in R_{2}(j, s)}\left(y_{i}-c_{2}\right)^{2}\right]
    where R_{1}(j, s) and R_{2}(j, s) are the two regions produced by the split and the optimal constants c_{1}, c_{2} are the means of y_{i} in each region.

  2. Classification tree
    The Gini index is computed as:
    \operatorname{Gini}(D)=1-\sum_{k=1}^{|\mathcal{Y}|} p_{k}^{2} \qquad \text{Gini-index}(D, a)=\sum_{v=1}^{V} \frac{\left|D^{v}\right|}{|D|} \operatorname{Gini}\left(D^{v}\right)

  3. Pruning
    Pruning actively removes branches to reduce the risk of overfitting; it comes in two forms, pre-pruning and post-pruning.

  • Pre-pruning
    While the tree is being grown, a node is split only if the split improves accuracy on a validation set; otherwise it is kept as a leaf.
  • Post-pruning
    A complete tree is grown first and then pruned from the bottom up, again judged on validation performance.
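
As promised after item 1, here is a minimal sketch of the squared-error split search for a single numeric feature (the function name best_regression_split and the exhaustive scan over unique values are illustrative assumptions, not CART's exact implementation):

import numpy as np

def best_regression_split(x, y):
    # Try every candidate split point s; the optimal constant for each
    # region is its mean, so the loss is the within-region squared error.
    best_s, best_loss = None, np.inf
    for s in np.unique(x)[:-1]:  # exclude the max so the right region is non-empty
        left, right = y[x <= s], y[x > s]
        loss = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if loss < best_loss:
            best_s, best_loss = s, loss
    return best_s, best_loss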

4.2 Code

import numpy as np
import pandas as pd

# Gini index of a list of labels
def gini(nums):
    probs = [nums.count(i) / len(nums) for i in set(nums)]
    return sum(p * (1 - p) for p in probs)  # equivalent to 1 - sum(p**2)

# Partition the data into sub-DataFrames, one per unique value of col
def split_dataframe(data, col):
    return {value: data[data[col] == value] for value in data[col].unique()}

# Choose the best feature by the Gini index
def choose_best_col(df, label):
    cols = [col for col in df.columns if col != label]
    min_value, best_col = float('inf'), None
    min_splited = None
    for col in cols:
        splited_set = split_dataframe(df, col)
        gini_DA = 0  # Gini index of the partition, weighted over subsets
        for subset_col, subset in splited_set.items():
            gini_Di = gini(subset[label].tolist())
            gini_DA += len(subset) / len(df) * gini_Di

        if gini_DA < min_value:
            min_value, best_col = gini_DA, col
            min_splited = splited_set
    return min_value, best_col, min_splited

# CART decision-tree class
class CartTree:    

    class Node:        
        def __init__(self, name):
            self.name = name
            self.connections = {}    
            
        def connect(self, label, node):
            self.connections[label] = node    
        
    def __init__(self, data, label):
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")    
    
    def print_tree(self, node, tabs):
        print(tabs + node.name)        
        for connection, child_node in node.connections.items():
            print(tabs + "\t" + "(" + connection + ")")
            self.print_tree(child_node, tabs + "\t\t")    
    
    def construct_tree(self):
        self.construct(self.root, "", self.data, self.columns)    
    
    def construct(self, parent_node, parent_connection_label, input_data, columns):
        # Stop if the node is pure: every remaining sample shares one label
        if len(input_data[self.label].unique()) == 1:
            node = self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label, node)
            return

        min_value, best_col, min_splited = choose_best_col(input_data[columns], self.label)

        # No feature left to split on: predict the majority label
        if not best_col:
            node = self.Node(input_data[self.label].value_counts().idxmax())
            parent_node.connect(parent_connection_label, node)
            return

        node = self.Node(best_col)
        parent_node.connect(parent_connection_label, node)

        new_columns = [col for col in columns if col != best_col]
        for splited_value, splited_data in min_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)

if __name__ == '__main__':
    df = pd.read_csv('../data.csv', dtype={'windy': 'str'})
    tree1 = CartTree(df, 'play')
    tree1.construct_tree()
    tree1.print_tree(tree1.root, "")
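
Strictly speaking, the class above grows multiway splits like the ID3 tree and only swaps the criterion; real CART grows binary trees and applies cost-complexity post-pruning. A minimal sketch of that pruning with scikit-learn, assuming the same hypothetical ../data.csv:

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

df = pd.read_csv('../data.csv', dtype={'windy': 'str'})
X = pd.get_dummies(df.drop(columns='play'))  # one-hot encode categorical features
X_tr, X_val, y_tr, y_val = train_test_split(X, df['play'], random_state=0)

# Each ccp_alpha on the path corresponds to one pruned subtree;
# keep the subtree that scores best on the validation set.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_tr, y_tr)
trees = [DecisionTreeClassifier(ccp_alpha=a, random_state=0).fit(X_tr, y_tr)
         for a in path.ccp_alphas]
best = max(trees, key=lambda t: t.score(X_val, y_val))
print(best.get_depth(), best.score(X_val, y_val))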

References

Theory: Zhou Zhihua, Machine Learning (機器學習); Li Hang, Statistical Learning Methods (統計學習方法)
Code: https://github.com/luwill/machine-learning-code-writing
