Machine Learning: Implementing a Decision Tree by Hand in Python

This post records my notes on decision trees.

Decision Trees

Tree models are among the models used most often in day-to-day work: their solid performance and good interpretability make them a common baseline. Normally I reach for sklearn, but I recently took on a project that required implementing a decision tree entirely by hand, so I am writing it up here. If you are working on something similar, feel free to email me at [email protected].

In an earlier blog post on decision trees I summarized the main ideas; here is a quick review:

Classification Trees

Building a classification tree follows these steps (a small sketch follows below):

0. Check whether a stopping condition has been reached (maximum tree depth, minimum samples per leaf, maximum error).
1. Find the best split feature according to some splitting criterion.
2. Split the samples on that feature: a binary tree splits two ways, a multi-way tree one way per feature value.
3. Repeat from step 0 on each resulting subset.

The classification algorithms differ mainly in the splitting criterion of step 1: information gain (ID3), gain ratio (C4.5), or the Gini index (CART). They also differ in step 2: ID3 and C4.5 split multi-way, building one branch per value of the chosen feature, while CART splits on a yes/no test of a single value of a single feature, which yields a binary tree. Because of this, CART tends to grow very deep binary trees and overfits the most severely of the three.
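As a minimal sketch of the step-1 criteria (the helper names and the toy labels here are mine, not part of the implementation below), the three measures can be compared on one candidate split:

from math import log2
from collections import Counter

def entropy(labels):
    # H(D) = -sum_i p_i * log2(p_i); used by ID3 (information gain) and C4.5 (gain ratio)
    n = len(labels)
    return -sum(c / n * log2(c / n) for c in Counter(labels).values())

def gini(labels):
    # Gini(D) = 1 - sum_i p_i**2; used by CART
    n = len(labels)
    return 1 - sum((c / n) ** 2 for c in Counter(labels).values())

parent = ['A', 'A', 'A', 'C', 'C']
left, right = ['A', 'A', 'A'], ['C', 'C']  # one candidate binary split
gain = entropy(parent) - (3 / 5) * entropy(left) - (2 / 5) * entropy(right)
print(round(gain, 3), round(gini(parent), 2))  # 0.971 0.48: this split removes all impurity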

Regression Trees

A regression tree is essentially piecewise least-squares fitting: the features partition the sample space into regions, and the tree outputs the mean of the training targets in each region.
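The implementation below covers classification only; as a hedged sketch of the regression case (the function name and toy data are mine), a single least-squares split scores every threshold by the squared error around each side's mean, which is the value the leaf would output:

def best_split_1d(xs, ys):
    # Pick the threshold that minimizes the total squared error around
    # each side's mean prediction.
    def sse(vals):
        if not vals:
            return 0.0
        m = sum(vals) / len(vals)
        return sum((v - m) ** 2 for v in vals)
    best_score, best_t = float('inf'), None
    for t in sorted(set(xs)):
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        score = sse(left) + sse(right)
        if score < best_score:
            best_score, best_t = score, t
    return best_t, best_score

print(best_split_1d([1, 2, 3, 10, 11], [1.0, 1.1, 0.9, 5.0, 5.2]))  # (3, ~0.04): splits between 3 and 10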

Pruning

A decision tree depends heavily on the data distribution, so a mismatch between the training and test distributions leads to severe overfitting. Pruning turns an overfit, complex tree into a simpler one and classifies with that simpler tree instead.
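The cost-complexity pruning (CCP) used below scores each internal node t by the standard formula

    alpha(t) = (R(t) - R(T_t)) / (|T_t| - 1)

where R(t) is the error cost of collapsing the subtree rooted at t into a single leaf, R(T_t) is the error cost of keeping the subtree, and |T_t| is the subtree's number of leaves. Each round prunes the node with the smallest alpha, i.e. the subtree whose extra complexity buys the least error reduction; the implementation below additionally adds a small smoothing constant to both numerator and denominator.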

Python Implementation

from math import log
from collections import defaultdict,Counter
from sklearn.model_selection import train_test_split
import sys
import matplotlib.pyplot as plt
# import tree


class DecisionTree(object):
    '''
    A from-scratch decision tree in pure Python, written to learn and to get a
    firmer grasp of how decision trees work.
    If you spot a mistake, or you are working on something similar, feel free to reach out:
    Email: [email protected]
    CSDN: https://blog.csdn.net/qq_22235017
    GitHub: https://github.com/ZhaoLiang-GitHub
    '''

    def __init__(self,config):
        self.config = config

    def _get_data_num(self, data, index, feature):
        # Count how many training rows fall under this subtree, i.e. how many
        # rows take value `feature` at column `index`.
        num = 0
        for x in data:
            if x[index] == feature:
                num += 1
        return num

    def _get_leaf_num(self, tree):
        # Count the leaves under this subtree.
        num = 0
        for key in tree.keys():
            if self._is_tree(tree[key]):
                num += self._get_leaf_num(tree[key])
            else:
                num += 1  # a leaf is not a dict
        return num

    def _get_error_value(self, tree, data, labels):
        # Weighted error cost of a subtree: each child's error, weighted by
        # its share of leaves over the number of samples.
        next_features = list(tree.keys())
        num = 0
        for next_feature in next_features:
            next_tree = tree[next_feature]
            leaf_num = self._get_leaf_num(next_tree)
            error_value = 0
            for next_tree_key in list(next_tree.keys()):
                error_value += self.clac_err(next_tree[next_tree_key], data, labels)
            num += (float(leaf_num) / len(data)) * error_value
        return num

    def _leaf_max_label(self, tree):
        # Most frequent leaf label in this subtree (used as the label of the
        # leaf that replaces a pruned subtree).
        self.label_count = {}

        def __leaf_max_label(tree):
            if self._is_tree(tree):
                key = list(tree.keys())[0]
                sub_tree = tree[key]
                for i in list(sub_tree.keys()):
                    __leaf_max_label(sub_tree[i])
            else:
                if tree not in self.label_count.keys():
                    self.label_count[tree] = 1
                else:
                    self.label_count[tree] += 1
        __leaf_max_label(tree)
        self.label_count = sorted(self.label_count.items(), key=lambda x: x[1], reverse=True)  # sort by count, descending
        return self.label_count[0][0]

    def _is_tree(self,obj):
        return isinstance(obj,dict)

    def _classify(self, input_tree, data, labels) -> str:
        # Classify one row with a multi-way tree: walk down until a label
        # (a non-dict) is reached.
        if not self._is_tree(input_tree):
            return input_tree
        feature = list(input_tree.keys())[0]
        first_dict = input_tree[feature]
        features_index = labels.index(feature)
        key = data[features_index]
        try:
            first_dict_or_value = first_dict[key]
        except KeyError:
            # the row has a feature value unseen in training: fall back to the first branch
            return self._classify(first_dict[list(first_dict.keys())[0]], data, labels)
        if self._is_tree(first_dict_or_value):
            label = self._classify(first_dict_or_value, data, labels)
        else:
            label = first_dict_or_value
        return label

    def _classify_for_cart(self, input_tree, data, labels) -> str:
        '''Classify one row with a CART tree: each node asks the yes/no
        question "does this feature equal this value?"'''
        if not self._is_tree(input_tree):
            return input_tree
        feature = list(input_tree.keys())[0]
        sub_tree = input_tree[feature]
        sub_tree_feature = list(sub_tree.keys())[0]
        features_index = labels.index(feature)
        branches = sub_tree[sub_tree_feature]
        if data[features_index] == sub_tree_feature:
            label = self._classify_for_cart(branches['is'], data, labels)
        else:
            # fall back to 'is' when the 'not is' branch was never built
            label = self._classify_for_cart(branches.get('not is', branches['is']), data, labels)
        return label

    def _calc_shannon_ent(self, dataset):
        '''
        Shannon entropy of a dataset, computed from the class frequencies:
        shannon_ent = -sum_i count(x_i)/count(data) * log2(count(x_i)/count(data))
        :param dataset: rows whose last column is the class label
        :return: entropy in bits
        '''
        num = len(dataset)
        labels_count = defaultdict(int)
        for feature in dataset:
            current_label = feature[-1]
            labels_count[current_label] += 1
        shannon_ent = 0.0
        for key in labels_count:
            prob = float(labels_count[key]) / num
            shannon_ent -= prob * log(prob, 2)  # base 2 gives bits, the usual ML convention; base e gives nats
        return shannon_ent
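    # Worked example (illustrative): for 10 samples labelled 7 'A' and 3 'C',
    # shannon_ent = -0.7*log2(0.7) - 0.3*log2(0.3) ≈ 0.881 bits.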

    def _split_data(self, dataset, axis, value):
        '''
        Split the data on column `axis`: return every row whose value at
        `axis` equals `value`, with column `axis` removed.
        '''
        retdataset = []
        for data_line in dataset:
            if data_line[axis] == value:
                reduced_data_line = data_line[:axis]
                reduced_data_line.extend(data_line[axis + 1:])
                retdataset.append(reduced_data_line)
        return retdataset
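    # Worked example (illustrative): _split_data([[1, 'x', 'A'], [2, 'y', 'B']], 1, 'x')
    # returns [[1, 'A']]: the matching row with column 1 removed.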

    def _split_data_for_cart(self, dataset, index, value):
        '''
        CART variant: return every row whose value at `index` equals `value`.
        Unlike _split_data, the feature column is kept, since CART may reuse a
        feature with a different value further down the tree.
        '''
        retdataset = []
        for data_line in dataset:
            if data_line[index] == value:
                retdataset.append(data_line)
        return retdataset

    def _choose_best_feature_by_information_entirpy(self, dataset):
        ''' ID3: choose the split feature by information gain
        :param dataset:
        :return: index of the best feature to split on
        '''
        num_features = len(dataset[0]) - 1  # the last column is the label
        base_entropy = self._calc_shannon_ent(dataset)  # entropy of the full dataset
        best_information_gain = 0.0  # larger gain means a purer split
        best_feature_index = -1
        for i in range(num_features):
            feature_list = [example[i] for example in dataset]
            unique_values = set(feature_list)
            entropy_by_feature = 0.0  # entropy after splitting on this feature
            for value in unique_values:
                # split the dataset on each value of the feature
                subdataset = self._split_data(dataset, i, value)
                prob = len(subdataset) / float(len(dataset))  # share of samples taking this value
                # Information gain:
                #
                #                            |D^v|
                # Gain(D,A) = Ent(D) - \sum ------ Ent(D^v)
                #                             |D|
                #
                # Ent(D) is base_entropy, the entropy of the full dataset;
                # |D^v| is the number of samples taking the v-th value of feature A,
                # |D| is the total sample count, and Ent(D^v) is the entropy of
                # that subset (accumulated here as entropy_by_feature).
                entropy_by_feature += prob * self._calc_shannon_ent(subdataset)
            information_gain = base_entropy - entropy_by_feature
            if information_gain > best_information_gain:
                best_information_gain = information_gain
                best_feature_index = i
        return best_feature_index

    def _choose_best_feature_by_gain_ratio(self, dataset):
        ''' C4.5: choose the split feature by gain ratio
        :param dataset:
        :return: index of the best feature to split on
        '''
        num_features = len(dataset[0]) - 1  # the last column is the label
        base_entropy = self._calc_shannon_ent(dataset)
        best_feature_index = -1
        best_gain_ratio = 0.0
        information_entropy_list = []  # information gain per feature
        gain_ratio_list = []  # gain ratio per feature
        for i in range(num_features):
            feature_list = [example[i] for example in dataset]
            unique_values = set(feature_list)
            entropy_by_feature = 0.0  # entropy after splitting on this feature
            intrinsic_value = 0.0  # intrinsic value of this feature
            for value in unique_values:
                # split the dataset on each value of the feature
                subdataset = self._split_data(dataset, i, value)
                prob = len(subdataset) / float(len(dataset))  # share of samples taking this value
                # Gain ratio:
                #
                #                   Gain(D,A)
                # gain_ratio = -----------------
                #               intrinsic_value
                #
                #                            |D^v|
                # Gain(D,A) = Ent(D) - \sum ------ Ent(D^v)
                #                             |D|
                #
                # intrinsic_value = -\sum prob * log2(prob)
                #
                # Dividing by the intrinsic value biases the criterion toward
                # features with few values, so C4.5 does not simply take the
                # highest gain ratio: it first keeps the features whose
                # information gain is above average, then picks the highest
                # gain ratio among those.
                entropy_by_feature += prob * self._calc_shannon_ent(subdataset)
                intrinsic_value -= prob * log(prob, 2)
            if intrinsic_value == 0:
                # single-valued feature: append placeholders to keep list
                # indices aligned with feature indices
                information_entropy_list.append(0.0)
                gain_ratio_list.append(0.0)
                continue
            gain_ratio = (base_entropy - entropy_by_feature) / intrinsic_value
            information_entropy_list.append(base_entropy - entropy_by_feature)
            gain_ratio_list.append(gain_ratio)
        if not information_entropy_list:
            return -1
        mean_information_entropy = sum(information_entropy_list) / len(information_entropy_list)
        for index, value in enumerate(information_entropy_list):
            if value > mean_information_entropy and gain_ratio_list[index] > best_gain_ratio:
                best_gain_ratio = gain_ratio_list[index]
                best_feature_index = index
        return best_feature_index

    def _count_most_label(self, label_list) -> str:
        '''
        Return the most frequent label in label_list.
        '''
        return Counter(label_list).most_common(1)[0][0]

    def _calc_ginii(self, dataSet):
        '''
        Gini index of a dataset: gini = 1 - sum_i p_i**2
        :param dataSet: rows whose last column is the class label
        :return: the Gini index
        '''
        data_len = len(dataSet)
        labelCounts = {}
        for featVec in dataSet:  # count the frequency of each label
            currentLabel = featVec[-1]
            if currentLabel not in labelCounts.keys():
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
        gini = 1.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / data_len
            gini -= prob * prob  # subtract the squared class probability
        return gini
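    # Worked example (illustrative): for 10 samples labelled 7 'A' and 3 'C',
    # gini = 1 - 0.7**2 - 0.3**2 = 1 - 0.49 - 0.09 = 0.42.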

    def _calc_gini_by_value_of_feature(self, dataSet, feature, value):
        '''
        Gini index of splitting on one (feature, value) pair
        :param dataSet: the data
        :param feature: feature column index
        :param value: the value of that feature to test against
        :return: the weighted Gini index of the binary split
        '''
        D0 = []
        D1 = []
        # binary split: does the row's feature equal value or not?
        for featVec in dataSet:
            if featVec[feature] == value:
                D0.append(featVec)
            else:
                D1.append(featVec)
        # Gini index: gini = 1 - \sum p**2, where p is each class's share of the data.
        # CART builds a binary tree: for feature A and value V the data is split
        # on "A == V?", and the (A, V) pair with the smallest weighted Gini wins.
        Gini = len(D0) / len(dataSet) * self._calc_ginii(D0) + len(D1) / len(dataSet) * self._calc_ginii(D1)
        return Gini

    def _choose_best_feature_by_gini(self, dataSet):
        # CART: choose the (feature, value) pair with the smallest weighted Gini,
        # skipping pairs already used higher up the tree (self.feature_value).
        features_num = len(dataSet[0]) - 1
        min_gini = float(sys.maxsize)
        best_feature_index = 0
        best_feature_value = ''
        for feature_index in range(features_num):
            featList = [example[feature_index] for example in dataSet]
            unique_values = set(featList)
            for value in unique_values:
                this_gini = self._calc_gini_by_value_of_feature(dataSet, feature_index, value)
                if this_gini < min_gini and str(feature_index) + str(value) not in self.feature_value:
                    best_feature_index = feature_index
                    min_gini = this_gini
                    best_feature_value = value
        return best_feature_index, best_feature_value

    def _create_mulit_fork_tree(self, dataset, LABELS) -> dict:
        '''
        Build a multi-way decision tree (ID3 / C4.5):
        pick the best split feature, then for each of its values recurse on the
        subset of rows taking that value, until a subset holds a single class.
        '''
        labels = LABELS[:]  # copy: lists are mutable and we delete from this one
        label_list = [example[-1] for example in dataset]
        if label_list.count(label_list[0]) == len(label_list):
            return label_list[0]  # only one class left
        # if len(dataset[0]) == 1:  # only one feature left
        #     return self._count_most_label(label_list)

        # gain ratio (C4.5); swap in _choose_best_feature_by_information_entirpy for plain information gain (ID3)
        best_feature_index = self._choose_best_feature_by_gain_ratio(dataset)
        if best_feature_index == -1:
            return self._count_most_label(label_list)  # no feature improves the split: majority class
        best_feature = labels[best_feature_index]
        tree = {best_feature: {}}
        del (labels[best_feature_index])
        feature_values = [example[best_feature_index] for example in dataset]
        unique_values = set(feature_values)
        for value in unique_values:
            sub_labels = labels[:]
            tree[best_feature][value] = self._create_mulit_fork_tree(
                self._split_data(dataset, best_feature_index, value), sub_labels)
        return tree
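    # The resulting tree is a nested dict with one feature name per level, e.g.
    # (illustrative shape, not actual output): {'born': {0: 'A', 1: {'age': {...}}, 2: 'C'}}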

    def _create_two_fork_tree(self, data, labels):
        '''
        Build a CART decision tree
        :param data: training data
        :param labels: feature names
        '''
        labels_list = [example[-1] for example in data]
        if len(labels_list) == 1 or labels_list.count(labels_list[0]) == len(labels_list):
            return labels_list[0]  # first stopping condition: every class label is the same
        if len(data) == 1:
            return data[0][-1]
        # if len(data[0]) == 1:
        #     return self._leaf_max_label(labels_list)  # second stopping condition: all features used up

        best_feature_index, best_feature_value = self._choose_best_feature_by_gini(data)  # best split
        if best_feature_value == '':
            return data[0][-1]
        self.feature_value.append(str(best_feature_index) + str(best_feature_value))
        best_feature = labels[best_feature_index]
        tree = {best_feature: {best_feature_value: {}}}  # the tree is stored as nested dicts
        is_data = self._split_data_for_cart(data, best_feature_index, best_feature_value)
        not_data = [example for example in data if example not in is_data]
        if len(is_data) == len(data):
            # the split separated nothing: close the branch with the majority class
            tree[best_feature][best_feature_value]['is'] = Counter(labels_list).most_common(1)[0][0]
        else:
            tree[best_feature][best_feature_value]['is'] = self._create_two_fork_tree(is_data, labels)
            tree[best_feature][best_feature_value]['not is'] = self._create_two_fork_tree(not_data, labels)

        return tree
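    # The resulting CART tree nests one (feature, value) question per level, e.g.
    # (illustrative shape, not actual output):
    # {'age': {1: {'is': 'A', 'not is': {'born': {2: {'is': 'C', 'not is': 'A'}}}}}}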

    @staticmethod
    def load_data():
        # Load the data.
        # The dataset is a 2-D list: each row is one sample and its last element is the label.
        dataset = [[1, 1, 0, 0, 0, 'A'], [0, 3, 2, 1, 0, 'C'], [2, 3, 2, 1, 0, 'A'], [1, 3, 0, 0, 0, 'C'], [1, 2, 1, 1, 0, 'C'],
                 [0, 2, 0, 0, 0, 'C'], [1, 0, 1, 1, 0, 'C'], [1, 2, 0, 0, 0, 'A'], [1, 2, 0, 0, 0, 'C'], [0, 3, 2, 1, 1, 'C'],
                 [1, 2, 0, 0, 0, 'C'], [2, 2, 1, 0, 0, 'A'], [1, 1, 1, 0, 0, 'A'], [1, 2, 0, 0, 0, 'C'], [1, 2, 0, 0, 0, 'A'],
                 [2, 1, 2, 0, 1, 'C'], [1, 3, 2, 0, 0, 'C'], [1, 2, 2, 0, 0, 'A'], [2, 2, 2, 0, 0, 'A'], [2, 1, 0, 0, 0, 'C'],
                 [2, 2, 1, 0, 0, 'A'], [2, 2, 0, 0, 0, 'C'], [2, 0, 0, 0, 0, 'A'], [0, 3, 1, 0, 0, 'A'], [0, 2, 0, 0, 0, 'C'],
                 [2, 0, 0, 0, 0, 'A'], [2, 2, 1, 1, 0, 'A'], [1, 3, 1, 1, 1, 'C'], [0, 2, 1, 0, 0, 'A'], [2, 0, 1, 1, 0, 'C'],
                 [1, 1, 0, 0, 0, 'A'], [2, 3, 1, 1, 0, 'A'], [2, 1, 2, 1, 1, 'C'], [0, 1, 2, 0, 1, 'A'], [2, 1, 1, 1, 0, 'C'],
                 [2, 0, 1, 1, 0, 'C'], [2, 0, 2, 1, 0, 'C'], [2, 2, 1, 1, 0, 'A'], [2, 3, 1, 1, 0, 'A'], [2, 2, 0, 0, 0, 'C'],
                 [1, 1, 2, 0, 1, 'C'], [2, 1, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'A'], [0, 1, 2, 1, 1, 'C'], [1, 3, 0, 0, 0, 'A'],
                 [2, 1, 0, 0, 0, 'C'], [0, 1, 2, 0, 0, 'C'], [0, 3, 0, 0, 0, 'A'], [1, 3, 0, 0, 0, 'A'], [2, 0, 0, 0, 0, 'C'],
                 [0, 3, 1, 0, 1, 'A'], [2, 1, 0, 0, 0, 'A'], [1, 1, 0, 0, 0, 'A'], [2, 2, 0, 0, 0, 'A'], [1, 3, 0, 0, 0, 'A'],
                 [1, 3, 0, 0, 0, 'A'], [2, 2, 1, 0, 0, 'C'], [0, 2, 1, 0, 0, 'A'], [0, 3, 0, 0, 0, 'A'], [0, 1, 1, 1, 0, 'C'],
                 [0, 2, 2, 0, 0, 'C'], [2, 2, 0, 0, 0, 'A'], [0, 2, 0, 0, 0, 'A'], [1, 1, 0, 0, 1, 'C'], [0, 3, 2, 0, 0, 'C'],
                 [0, 3, 2, 1, 1, 'C'], [2, 2, 1, 0, 0, 'C'], [2, 2, 1, 0, 0, 'C'], [2, 3, 0, 0, 0, 'A'], [2, 0, 0, 0, 0, 'A'],
                 [1, 1, 0, 0, 0, 'C'], [1, 2, 2, 1, 0, 'A'], [1, 2, 0, 0, 0, 'C'], [0, 2, 1, 1, 0, 'A'], [2, 1, 0, 0, 0, 'A'],
                 [1, 2, 2, 0, 0, 'A'], [2, 2, 0, 0, 0, 'A'], [1, 3, 2, 0, 0, 'A'], [1, 2, 1, 1, 0, 'A'], [1, 3, 2, 0, 0, 'C'],
                 [2, 1, 0, 0, 0, 'C'], [0, 3, 0, 1, 0, 'A'], [2, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [2, 1, 1, 0, 0, 'A'],
                 [2, 1, 0, 0, 0, 'C'], [1, 3, 2, 0, 1, 'C'], [1, 2, 1, 1, 0, 'C'], [0, 3, 2, 1, 1, 'A'], [2, 2, 0, 0, 0, 'A'],
                 [0, 3, 2, 0, 1, 'A'], [0, 3, 2, 0, 1, 'C'], [0, 3, 2, 0, 1, 'A'], [2, 1, 0, 0, 0, 'A'], [0, 3, 2, 0, 0, 'A'],
                 [2, 3, 0, 0, 0, 'A'], [1, 2, 0, 0, 0, 'A'], [1, 1, 0, 0, 0, 'A'], [1, 3, 0, 0, 0, 'C'], [1, 2, 1, 1, 0, 'C'],
                 [0, 1, 2, 0, 1, 'A'], [1, 1, 2, 1, 1, 'C'], [1, 1, 2, 1, 1, 'C'], [0, 3, 1, 0, 1, 'A'], [0, 2, 0, 0, 0, 'A'],
                 [1, 1, 2, 0, 1, 'C'], [1, 2, 0, 0, 0, 'A'], [2, 1, 1, 1, 0, 'A'], [1, 3, 2, 1, 0, 'A'], [1, 3, 0, 0, 0, 'C'],
                 [2, 0, 0, 0, 0, 'A'], [1, 0, 2, 0, 1, 'A'], [1, 1, 0, 0, 0, 'C'], [1, 1, 2, 0, 1, 'A'], [0, 2, 1, 1, 0, 'C'],
                 [1, 1, 1, 0, 1, 'C'], [0, 3, 0, 0, 0, 'A'], [2, 0, 2, 0, 0, 'A'], [1, 1, 2, 0, 1, 'C'], [0, 0, 1, 1, 0, 'C'],
                 [0, 2, 0, 0, 0, 'A'], [1, 1, 2, 1, 1, 'C'], [1, 3, 0, 0, 0, 'A'], [1, 2, 1, 0, 0, 'C'], [1, 2, 2, 0, 0, 'A'],
                 [2, 1, 1, 0, 0, 'A'], [0, 2, 0, 0, 0, 'A'], [1, 0, 0, 0, 0, 'C'], [0, 2, 0, 0, 0, 'C'], [0, 1, 2, 0, 1, 'A'],
                 [0, 3, 0, 1, 1, 'A'], [1, 3, 2, 0, 0, 'A'], [1, 1, 1, 1, 1, 'A'], [0, 2, 0, 1, 0, 'A'], [0, 3, 2, 1, 1, 'A'],
                 [1, 1, 2, 0, 1, 'C'], [1, 2, 1, 1, 0, 'C'], [2, 1, 0, 0, 0, 'A'], [1, 2, 1, 1, 0, 'C'], [1, 2, 0, 0, 0, 'C'],
                 [1, 3, 2, 0, 1, 'C'], [2, 3, 0, 0, 0, 'C'], [0, 3, 1, 0, 0, 'A'], [1, 3, 2, 1, 0, 'A'], [0, 2, 2, 0, 0, 'A'],
                 [1, 3, 0, 0, 0, 'C'], [1, 0, 0, 0, 0, 'A'], [1, 2, 1, 1, 0, 'A'], [1, 3, 1, 0, 0, 'A'], [0, 3, 2, 0, 1, 'C'],
                 [2, 0, 0, 0, 0, 'C'], [1, 1, 1, 1, 0, 'A'], [0, 3, 2, 1, 1, 'C'], [0, 2, 2, 0, 0, 'A'], [2, 0, 1, 1, 0, 'C'],
                 [2, 0, 1, 1, 0, 'C'], [2, 2, 1, 1, 0, 'C'], [1, 2, 0, 0, 0, 'C'], [2, 1, 0, 0, 0, 'A'], [2, 0, 1, 1, 0, 'A'],
                 [0, 3, 0, 1, 0, 'A'], [0, 3, 1, 1, 0, 'C'], [1, 3, 1, 0, 0, 'A'], [2, 2, 0, 0, 0, 'A'], [0, 3, 0, 0, 0, 'A'],
                 [1, 3, 2, 0, 1, 'C'], [2, 1, 1, 0, 0, 'C'], [2, 2, 2, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [0, 2, 1, 0, 0, 'C'],
                 [2, 2, 0, 0, 0, 'C'], [0, 3, 2, 0, 1, 'C'], [1, 2, 0, 0, 0, 'C'], [1, 3, 1, 1, 1, 'C'], [2, 1, 2, 0, 0, 'C'],
                 [2, 1, 0, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [0, 1, 0, 0, 1, 'C'], [1, 2, 2, 0, 0, 'C'], [0, 1, 2, 0, 1, 'C'],
                 [0, 2, 1, 1, 1, 'C'], [0, 2, 1, 0, 0, 'C'], [0, 3, 2, 1, 0, 'C'], [2, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'],
                 [1, 3, 2, 0, 1, 'C'], [2, 1, 0, 0, 0, 'C'], [1, 1, 0, 0, 0, 'C'], [2, 1, 1, 0, 0, 'C'], [0, 2, 1, 1, 0, 'C'],
                 [1, 1, 0, 0, 0, 'C'], [2, 1, 2, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [2, 2, 1, 0, 0, 'C'], [1, 3, 0, 0, 0, 'C']]
        # a hand-picked test set
        test = [[1, 2, 1, 1, 0, 'C'], [1, 3, 2, 0, 0, 'C'], [2, 0, 1, 1, 0, 'C'], [1, 2, 0, 0, 0, 'A'], [2, 3, 0, 0, 0, 'A'],
                [0, 2, 2, 0, 0, 'A'], [1, 3, 0, 0, 0, 'A'], [2, 1, 2, 0, 1, 'C'], [1, 1, 0, 0, 0, 'C'], [1, 0, 2, 0, 1, 'A'],
                [2, 1, 0, 0, 0, 'C'], [2, 2, 2, 0, 0, 'A'], [2, 2, 1, 0, 0, 'C'], [0, 3, 2, 1, 1, 'C'], [0, 3, 0, 0, 0, 'A'],
                [2, 0, 2, 0, 0, 'A'], [2, 0, 0, 0, 0, 'A'], [1, 1, 2, 0, 1, 'A'], [2, 1, 0, 0, 0, 'C'], [0, 0, 1, 1, 0, 'C'],
                [1, 2, 1, 0, 0, 'C'], [0, 2, 2, 0, 0, 'A'], [2, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [2, 1, 0, 0, 0, 'A'],
                [1, 1, 1, 0, 1, 'C'], [2, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'A'], [1, 1, 0, 0, 0, 'A'], [1, 1, 2, 1, 1, 'C'],
                [1, 1, 2, 0, 1, 'C'], [1, 2, 2, 0, 0, 'A'], [0, 2, 0, 0, 0, 'C'], [1, 2, 0, 0, 0, 'C'], [1, 0, 0, 0, 0, 'C'],
                [0, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [0, 2, 1, 1, 0, 'A']]
        labels = ['age','profession','born','location','insurance']
        return dataset,labels,test

    def split_trian_test(self,dataset,test_size):
        train_data,test_data = train_test_split(dataset,test_size = test_size)
        return train_data,test_data

    def pruning_ccp_for_mulit_fork(self, input_tree, data, data_labels):
        # Cost-complexity pruning (CCP): each round collapses the subtree with
        # the smallest alpha; the pruned tree with the lowest test error wins.
        tree_list = [input_tree]
        error_list = [self.clac_err(input_tree, data, data_labels)]
        num_data = len(data)

        def _pruning_cpp(input_tree, data, labels):
            if not self._is_tree(input_tree):
                return
            feature = list(input_tree.keys())[0]
            feature_index = labels.index(feature)
            pt = float(self._get_data_num(data, index=feature_index, feature=feature)) / float(num_data)  # share of all data reaching node t
            NT = self._get_leaf_num(input_tree)  # number of leaves under node t
            rt = self.clac_err(input_tree, data, labels)  # error rate of the current subtree
            Rt = rt * pt  # weighted error cost of node t
            RTt = self._get_error_value(input_tree, data, labels)  # error cost of the subtree
            lambda_ = 0.01  # smoothing constant
            alpha = (Rt - RTt + lambda_) / (NT - 1 + lambda_)
            if alpha < self.min_alpha:
                self.min_alpha = alpha
                self.min_tree = input_tree
            sub_tree = input_tree[feature]  # descend into the children
            for key in list(sub_tree.keys()):
                if self._is_tree(sub_tree[key]):
                    _pruning_cpp(sub_tree[key], data, labels)

        def _find_tree(parent, child, label):
            # Locate `child` inside `parent` and replace it with `label`.
            if parent == child:
                return label
            sub_tree = parent[list(parent.keys())[0]]
            for j in sub_tree.keys():
                if self._is_tree(sub_tree[j]):
                    sub_tree[j] = _find_tree(sub_tree[j], child, label)
            return parent

        def _del_min_tree():
            # Collapse the minimum-alpha subtree into its majority leaf label.
            label = self._leaf_max_label(self.min_tree)
            tree = _find_tree(input_tree, self.min_tree, label)
            return tree

        for pruning_num_i in range(self.config['pruning_num']):
            # reset the running minimum before each full pass over the tree
            self.min_alpha = sys.maxsize
            self.min_tree = input_tree
            _pruning_cpp(input_tree, data, data_labels)
            tree = _del_min_tree()  # the tree after this pruning round
            input_tree = tree
            if not self._is_tree(input_tree):
                print('Tree after {} pruning round(s)\n'.format(pruning_num_i + 1), {'all': input_tree})
                print('Test error after {} pruning round(s):'.format(pruning_num_i + 1), self.clac_err(tree, data, data_labels))
                tree_list.append({'all': input_tree})
                error_list.append(self.clac_err(tree, data, data_labels))
            else:
                print('Tree after {} pruning round(s)\n'.format(pruning_num_i + 1), input_tree)
                print('Test error after {} pruning round(s):'.format(pruning_num_i + 1), self.clac_err(tree, data, data_labels))
                tree_list.append(tree)
                error_list.append(self.clac_err(tree, data, data_labels))
        min_error = sys.maxsize
        min_error_tree = None
        min_index = 0
        for index, value in enumerate(error_list):
            if value < min_error:
                min_error = value
                min_error_tree = tree_list[index]
                min_index = index
        print('The test error is smallest after {} pruning round(s) (error {}); that tree is kept as the final tree'.format(min_index, min_error))
        print(min_error_tree)
        return min_error_tree

    def clac_err(self, tree, data, labels):
        # Misclassification rate of the tree over all rows in data.
        error = 0.0
        for i in range(len(data)):
            if self.config['tree_fork'] == 2:
                if self._classify_for_cart(tree, data[i], labels) != data[i][-1]:
                    error += 1
            else:
                if self._classify(tree, data[i], labels) != data[i][-1]:
                    error += 1
        return float(error) / float(len(data))

    def create_tree(self, data, labels):
        if self.config['tree_fork'] == 2:
            self.feature_value = []  # (feature, value) pairs already used by the CART tree
            tree = self._create_two_fork_tree(data, labels)
        else:
            tree = self._create_mulit_fork_tree(data, labels)
        return tree

if __name__ == '__main__':

    config = {
        'tree_fork': 2,  # 2 builds a binary CART tree; any other value builds a multi-way C4.5 tree (switchable to ID3), which is the default
        'prun_type': '',  # post-pruning type; available: CCP
        'pruning_num': 20,  # number of pruning rounds
    }
    dataset, labels, test_data = DecisionTree.load_data()  # load the data and the feature names
    decision_tree = DecisionTree(config)  # build a DecisionTree instance
    train_data, test_data = decision_tree.split_trian_test(dataset, 0.3)  # random train/test split (replaces the hand-picked test set)

    tree = decision_tree.create_tree(train_data, labels)  # build the tree
    print('The decision tree built on the training data:', tree)
    # decision_tree.show_tree(tree)
    print('Test error of the unpruned tree:', decision_tree.clac_err(tree, test_data, labels))
    finally_tree = decision_tree.pruning_ccp_for_mulit_fork(tree, test_data, labels)  # the tree after pruning
