


# 樹模型是日常工作中使用頻率最高的模型之一:因其較好的模型效果與良好的可解釋性,經常作爲baseline模型使用。
# 平時通常直接調用sklearn庫;最近遇到了一個單子,需要完全手動實現決策樹相關內容,在此將其記錄下來。
# 如果有同學在做相關內容,請與我郵件聯繫:[email protected]









from math import log
from collections import defaultdict,Counter
from sklearn.model_selection import train_test_split
import sys
import matplotlib.pyplot as plt
# import tree

class DecisionTree(object):
    """Hand-written decision tree classifier (ID3 / C4.5 / CART) with pruning.

    Email: [email protected]  (the address is redacted in the source)

    Data format: every sample is a list whose last element is the class
    label and whose other elements are discrete feature values.

    Tree format:
      * leaf: the class label itself (any non-dict value);
      * multi-way node (ID3 / C4.5): ``{feature_name: {value: child, ...}}``;
      * binary node (CART):
        ``{feature_name: {value: {'is': child, 'not is': child}}}``.

    ``config`` keys read by this class: ``'tree_fork'`` (2 selects the binary
    CART tree, anything else the multi-way C4.5 tree) and ``'pruning_num'``
    (maximum number of pruning rounds).
    """

    def __init__(self, config):
        """Store the configuration dict.

        :param config: dict of settings, see the class docstring.
        """
        self.config = config

    def _uses_binary_tree(self):
        """Return True when the configuration selects the binary (CART) tree."""
        return self.config.get('tree_fork') == 2

    def _get_data_num(self, data, index, feature):
        """Count the samples whose column ``index`` equals ``feature``.

        :param data: list of samples.
        :param index: column to inspect.
        :param feature: value the column is compared against.
        :return: number of matching samples.
        """
        return sum(1 for row in data if row[index] == feature)

    def _get_leaf_num(self, tree):
        """Count the leaves below ``tree``; a bare label counts as one leaf.

        :param tree: a tree node or a leaf label.
        :return: number of leaves.
        """
        if not self._is_tree(tree):
            return 1
        return sum(self._get_leaf_num(child) for child in tree.values())

    def _get_error_value(self, tree, data, labels=None, total=None):
        """Return the error cost of ``tree`` on ``data``.

        The cost is the number of samples in ``data`` that ``tree``
        misclassifies, divided by ``total``.

        :param tree: tree (or leaf label) used for classification.
        :param data: samples to classify; the last element is the true label.
        :param labels: feature names, in the column order of ``data``.
        :param total: denominator; defaults to ``len(data)``.
        :return: error cost as a float, 0.0 when ``total`` is 0.
        :raises ValueError: if ``labels`` is not given.
        """
        if labels is None:
            raise ValueError('labels (the feature names) are required')
        if total is None:
            total = len(data)
        if not total:
            return 0.0
        if self._uses_binary_tree():
            classify = self._classify_for_cart
        else:
            classify = self._classify
        wrong = sum(1 for row in data if classify(tree, row, labels) != row[-1])
        return float(wrong) / float(total)

    def _leaf_label_counts(self, tree):
        """Return a Counter of the leaf labels found below ``tree``.

        Leaves are visited depth-first in dict order, so ties in
        ``most_common`` go to the label that is met first.
        """
        counts = Counter()

        def visit(node):
            if self._is_tree(node):
                for child in node.values():
                    visit(child)
            else:
                counts[node] += 1

        visit(tree)
        return counts

    def _leaf_max_label(self, tree):
        """Return the label carried by the most leaves below ``tree``.

        Each leaf counts once; the result is not weighted by samples.
        As in the original implementation, ``self.label_count`` is left
        holding the ``(label, count)`` pairs sorted by descending count.

        :param tree: a tree node or a leaf label.
        :return: the most frequent leaf label.
        """
        self.label_count = self._leaf_label_counts(tree).most_common()
        return self.label_count[0][0]

    def _is_tree(self, obj):
        """Return True if ``obj`` is an internal node (a dict), not a leaf."""
        return isinstance(obj, dict)

    def _classify(self, input_tree, data, labels) -> str:
        """Classify one sample with a multi-way (ID3 / C4.5) tree.

        :param input_tree: tree node or leaf label.
        :param data: one sample; only feature columns are read.
        :param labels: feature names, in the column order of ``data``.
        :return: the predicted class label.
        """
        if not self._is_tree(input_tree):
            return input_tree
        feature = next(iter(input_tree))
        branches = input_tree[feature]
        value = data[labels.index(feature)]
        if value not in branches:
            # Value never seen while building this node: answer with the
            # most frequent leaf label below it, so a label (never a
            # sub-dict) is returned.
            return self._leaf_label_counts(input_tree).most_common(1)[0][0]
        return self._classify(branches[value], data, labels)

    def _classify_for_cart(self, input_tree, data, labels) -> str:
        """Classify one sample with a binary (CART) tree.

        :param input_tree: tree node or leaf label.
        :param data: one sample; only feature columns are read.
        :param labels: feature names, in the column order of ``data``.
        :return: the predicted class label, or None when the node's feature
            name is not in ``labels``.
        """
        if not self._is_tree(input_tree):
            return input_tree
        feature = next(iter(input_tree))
        if feature not in labels:
            return None
        branches = input_tree[feature]
        split_value = next(iter(branches))
        matches = data[labels.index(feature)] == split_value
        side = 'is' if matches else 'not is'
        return self._classify_for_cart(branches[split_value][side], data, labels)

    def _calc_shannon_ent(self, dataset):
        """Compute the Shannon entropy of the class labels in ``dataset``.

        entropy = -sum(p_i * log2(p_i)), with p_i = count(label_i) / count(data)

        :param dataset: list of samples, label in the last column.
        :return: entropy in bits; 0.0 for an empty data set.
        """
        num = len(dataset)
        labels_count = Counter(row[-1] for row in dataset)
        shannon_ent = 0.0
        for count in labels_count.values():
            prob = float(count) / num
            # Base 2 gives the entropy in bits.
            shannon_ent -= prob * log(prob, 2)
        return shannon_ent

    def _split_data(self, dataset, axis, value):
        """Select the samples with ``sample[axis] == value``, dropping that column.

        :param dataset: list of samples.
        :param axis: column index of the feature to split on.
        :param value: feature value to keep.
        :return: new list of reduced samples (the inputs are not modified).
        """
        return [
            data_line[:axis] + data_line[axis + 1:]
            for data_line in dataset
            if data_line[axis] == value
        ]

    def _split_data_for_cart(self, dataset, index, value):
        """Select the samples with ``sample[index] == value``, keeping all columns.

        :param dataset: list of samples.
        :param index: column index of the feature to split on.
        :param value: feature value to keep.
        :return: list of the matching samples (same row objects).
        """
        return [data_line for data_line in dataset if data_line[index] == value]

    def _choose_best_feature_by_information_entirpy(self, dataset):
        """ID3: choose the feature with the largest information gain.

        Gain(D, A) = Ent(D) - sum_v(|D^v| / |D| * Ent(D^v)), where D^v is the
        subset of D whose feature A takes its v-th value.

        :param dataset: list of samples, label in the last column.
        :return: index of the best feature, or -1 if no feature has a
            positive gain (or the data set is empty).
        """
        if not dataset:
            return -1
        num_features = len(dataset[0]) - 1  # the last column is the label
        base_entropy = self._calc_shannon_ent(dataset)
        total = float(len(dataset))
        best_information_entrop = 0.0
        best_feature_index = -1
        for i in range(num_features):
            entropy_by_feature = 0.0  # entropy left after splitting on feature i
            for value in set(example[i] for example in dataset):
                subdataset = self._split_data(dataset, i, value)
                prob = len(subdataset) / total
                entropy_by_feature += prob * self._calc_shannon_ent(subdataset)
            information_entrop = base_entropy - entropy_by_feature
            if information_entrop > best_information_entrop:
                best_information_entrop = information_entrop
                best_feature_index = i
        return best_feature_index

    def _choose_best_feature_by_gain_ratio(self, dataset):
        """C4.5: choose the feature with the best information gain ratio.

        gain_ratio = Gain(D, A) / IV(A)
        Gain(D, A) = Ent(D) - sum_v(|D^v| / |D| * Ent(D^v))
        IV(A)      = -sum_v(p_v * log2(p_v)),  p_v = |D^v| / |D|

        Dividing by IV(A) favours features with few values, so the ratio is
        only compared among features whose gain is at least the average
        gain of the candidates.

        :param dataset: list of samples, label in the last column.
        :return: index of the best feature, or -1 if no feature gives a
            positive gain ratio (or the data set is empty).
        """
        if not dataset:
            return -1
        num_features = len(dataset[0]) - 1  # the last column is the label
        base_entropy = self._calc_shannon_ent(dataset)
        total = float(len(dataset))
        # (feature index, gain, gain ratio); the index is stored explicitly
        # so that skipped features cannot shift the positions.
        candidates = []
        for i in range(num_features):
            entropy_by_feature = 0.0
            intrinsic_value = 0.0
            for value in set(example[i] for example in dataset):
                subdataset = self._split_data(dataset, i, value)
                prob = len(subdataset) / total
                entropy_by_feature += prob * self._calc_shannon_ent(subdataset)
                intrinsic_value -= prob * log(prob, 2)
            if intrinsic_value == 0:
                # Single-valued feature: it cannot split the data and the
                # ratio would divide by zero.
                continue
            gain = base_entropy - entropy_by_feature
            candidates.append((i, gain, gain / intrinsic_value))
        if not candidates:
            return -1
        mean_gain = sum(gain for _, gain, _ in candidates) / len(candidates)
        best_feature_index = -1
        best_gain_ratio = 0.0
        for index, gain, ratio in candidates:
            # ">=" with a small tolerance lets a sole candidate (whose gain
            # equals the mean) be selected.
            if gain >= mean_gain - 1e-12 and ratio > best_gain_ratio:
                best_gain_ratio = ratio
                best_feature_index = index
        return best_feature_index

    def _count_most_label(self, label_list) -> str:
        """Return the most frequent label of ``label_list``.

        :param label_list: non-empty list of class labels.
        :return: the label with the highest count (first seen wins ties).
        """
        return Counter(label_list).most_common(1)[0][0]

    def _calc_ginii(self, dataSet):
        """Compute the Gini impurity of the class labels: 1 - sum(p_i ** 2).

        :param dataSet: list of samples, label in the last column.
        :return: Gini impurity; 1.0 for an empty data set (no term is
            subtracted).
        """
        data_len = len(dataSet)
        label_counts = Counter(row[-1] for row in dataSet)
        gini = 1.0
        for count in label_counts.values():
            prob = float(count) / data_len
            gini -= prob * prob
        return gini

    def _calc_gini_by_value_of_feature(self, dataSet, feature, value):
        """Compute the Gini index of the binary split ``feature == value``.

        The data is divided into D0 (feature equals ``value``) and D1 (the
        rest); the result is the size-weighted sum of both impurities.

        :param dataSet: list of samples, label in the last column.
        :param feature: column index of the feature.
        :param value: feature value that defines the split.
        :return: weighted Gini index of the split.
        """
        D0 = []
        D1 = []
        for row in dataSet:
            if row[feature] == value:
                D0.append(row)
            else:
                D1.append(row)
        total = float(len(dataSet))
        return (len(D0) / total * self._calc_ginii(D0)
                + len(D1) / total * self._calc_ginii(D1))

    def _choose_best_feature_by_gini(self, dataSet):
        """CART: choose the (feature, value) split with the smallest Gini index.

        Features with a single distinct value are skipped because such a
        split leaves one side empty. Pairs whose key
        ``str(index) + str(value)`` is listed in ``self.feature_value``
        (if that attribute exists) are skipped as well.

        :param dataSet: list of samples, label in the last column.
        :return: ``(feature_index, value)``; ``value`` is ``''`` when no
            usable split exists.
        """
        excluded = getattr(self, 'feature_value', ())
        min_gini = float('inf')
        best_feature_index = 0
        best_feature_value = ''
        for feature_index in range(len(dataSet[0]) - 1):
            # dict.fromkeys keeps first-seen order, so ties are resolved
            # the same way on every run.
            unique_values = list(dict.fromkeys(row[feature_index] for row in dataSet))
            if len(unique_values) < 2:
                continue
            for value in unique_values:
                if str(feature_index) + str(value) in excluded:
                    continue
                this_gini = self._calc_gini_by_value_of_feature(dataSet, feature_index, value)
                if this_gini < min_gini:
                    min_gini = this_gini
                    best_feature_index = feature_index
                    best_feature_value = value
        return best_feature_index, best_feature_value

    def _create_mulit_fork_tree(self, dataset, LABELS) -> dict:
        """Build a multi-way decision tree (C4.5 gain ratio) recursively.

        To build an ID3 tree instead, call
        ``_choose_best_feature_by_information_entirpy`` below.

        :param dataset: non-empty list of samples, label in the last column.
        :param LABELS: feature names matching the columns of ``dataset``;
            the list is not modified.
        :return: a nested dict, or a bare label when the node is a leaf.
        """
        labels = LABELS[:]  # lists are mutable: work on a copy
        label_list = [example[-1] for example in dataset]
        if label_list.count(label_list[0]) == len(label_list):
            return label_list[0]  # only one class left

        best_feature_index = self._choose_best_feature_by_gain_ratio(dataset)
        if best_feature_index == -1:
            # No useful feature left: check this BEFORE indexing labels and
            # fall back to the majority class.
            return self._count_most_label(label_list)
        best_feature = labels[best_feature_index]
        tree = {best_feature: {}}
        del labels[best_feature_index]
        for value in dict.fromkeys(example[best_feature_index] for example in dataset):
            tree[best_feature][value] = self._create_mulit_fork_tree(
                self._split_data(dataset, best_feature_index, value), labels)
        return tree

    def _create_two_fork_tree(self, data, labels):
        """Build a binary CART tree recursively.

        :param data: non-empty list of training samples, label last.
        :param labels: feature names matching the columns of ``data``.
        :return: a nested dict, or a bare label when the node is a leaf.
        """
        labels_list = [example[-1] for example in data]
        if labels_list.count(labels_list[0]) == len(labels_list):
            return labels_list[0]  # all samples share one class
        majority = Counter(labels_list).most_common(1)[0][0]

        best_feature_index, best_feature_value = self._choose_best_feature_by_gini(data)
        if best_feature_value == '':
            return majority  # no usable split left
        is_data = self._split_data_for_cart(data, best_feature_index, best_feature_value)
        not_data = [example for example in data
                    if example[best_feature_index] != best_feature_value]
        if not is_data or not not_data:
            # A one-sided split would recurse on the same data forever.
            return majority
        best_feature = labels[best_feature_index]
        return {best_feature: {best_feature_value: {
            'is': self._create_two_fork_tree(is_data, labels),
            'not is': self._create_two_fork_tree(not_data, labels),
        }}}

    @staticmethod
    def load_data():
        """Return the built-in sample data.

        :return: ``(dataset, labels, test)``: the training samples, the five
            feature names and a hand-picked test set. Every sample is a
            list of five feature values followed by the class label.
        """
        dataset = [[1, 1, 0, 0, 0, 'A'], [0, 3, 2, 1, 0, 'C'], [2, 3, 2, 1, 0, 'A'], [1, 3, 0, 0, 0, 'C'], [1, 2, 1, 1, 0, 'C'],
                 [0, 2, 0, 0, 0, 'C'], [1, 0, 1, 1, 0, 'C'], [1, 2, 0, 0, 0, 'A'], [1, 2, 0, 0, 0, 'C'], [0, 3, 2, 1, 1, 'C'],
                 [1, 2, 0, 0, 0, 'C'], [2, 2, 1, 0, 0, 'A'], [1, 1, 1, 0, 0, 'A'], [1, 2, 0, 0, 0, 'C'], [1, 2, 0, 0, 0, 'A'],
                 [2, 1, 2, 0, 1, 'C'], [1, 3, 2, 0, 0, 'C'], [1, 2, 2, 0, 0, 'A'], [2, 2, 2, 0, 0, 'A'], [2, 1, 0, 0, 0, 'C'],
                 [2, 2, 1, 0, 0, 'A'], [2, 2, 0, 0, 0, 'C'], [2, 0, 0, 0, 0, 'A'], [0, 3, 1, 0, 0, 'A'], [0, 2, 0, 0, 0, 'C'],
                 [2, 0, 0, 0, 0, 'A'], [2, 2, 1, 1, 0, 'A'], [1, 3, 1, 1, 1, 'C'], [0, 2, 1, 0, 0, 'A'], [2, 0, 1, 1, 0, 'C'],
                 [1, 1, 0, 0, 0, 'A'], [2, 3, 1, 1, 0, 'A'], [2, 1, 2, 1, 1, 'C'], [0, 1, 2, 0, 1, 'A'], [2, 1, 1, 1, 0, 'C'],
                 [2, 0, 1, 1, 0, 'C'], [2, 0, 2, 1, 0, 'C'], [2, 2, 1, 1, 0, 'A'], [2, 3, 1, 1, 0, 'A'], [2, 2, 0, 0, 0, 'C'],
                 [1, 1, 2, 0, 1, 'C'], [2, 1, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'A'], [0, 1, 2, 1, 1, 'C'], [1, 3, 0, 0, 0, 'A'],
                 [2, 1, 0, 0, 0, 'C'], [0, 1, 2, 0, 0, 'C'], [0, 3, 0, 0, 0, 'A'], [1, 3, 0, 0, 0, 'A'], [2, 0, 0, 0, 0, 'C'],
                 [0, 3, 1, 0, 1, 'A'], [2, 1, 0, 0, 0, 'A'], [1, 1, 0, 0, 0, 'A'], [2, 2, 0, 0, 0, 'A'], [1, 3, 0, 0, 0, 'A'],
                 [1, 3, 0, 0, 0, 'A'], [2, 2, 1, 0, 0, 'C'], [0, 2, 1, 0, 0, 'A'], [0, 3, 0, 0, 0, 'A'], [0, 1, 1, 1, 0, 'C'],
                 [0, 2, 2, 0, 0, 'C'], [2, 2, 0, 0, 0, 'A'], [0, 2, 0, 0, 0, 'A'], [1, 1, 0, 0, 1, 'C'], [0, 3, 2, 0, 0, 'C'],
                 [0, 3, 2, 1, 1, 'C'], [2, 2, 1, 0, 0, 'C'], [2, 2, 1, 0, 0, 'C'], [2, 3, 0, 0, 0, 'A'], [2, 0, 0, 0, 0, 'A'],
                 [1, 1, 0, 0, 0, 'C'], [1, 2, 2, 1, 0, 'A'], [1, 2, 0, 0, 0, 'C'], [0, 2, 1, 1, 0, 'A'], [2, 1, 0, 0, 0, 'A'],
                 [1, 2, 2, 0, 0, 'A'], [2, 2, 0, 0, 0, 'A'], [1, 3, 2, 0, 0, 'A'], [1, 2, 1, 1, 0, 'A'], [1, 3, 2, 0, 0, 'C'],
                 [2, 1, 0, 0, 0, 'C'], [0, 3, 0, 1, 0, 'A'], [2, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [2, 1, 1, 0, 0, 'A'],
                 [2, 1, 0, 0, 0, 'C'], [1, 3, 2, 0, 1, 'C'], [1, 2, 1, 1, 0, 'C'], [0, 3, 2, 1, 1, 'A'], [2, 2, 0, 0, 0, 'A'],
                 [0, 3, 2, 0, 1, 'A'], [0, 3, 2, 0, 1, 'C'], [0, 3, 2, 0, 1, 'A'], [2, 1, 0, 0, 0, 'A'], [0, 3, 2, 0, 0, 'A'],
                 [2, 3, 0, 0, 0, 'A'], [1, 2, 0, 0, 0, 'A'], [1, 1, 0, 0, 0, 'A'], [1, 3, 0, 0, 0, 'C'], [1, 2, 1, 1, 0, 'C'],
                 [0, 1, 2, 0, 1, 'A'], [1, 1, 2, 1, 1, 'C'], [1, 1, 2, 1, 1, 'C'], [0, 3, 1, 0, 1, 'A'], [0, 2, 0, 0, 0, 'A'],
                 [1, 1, 2, 0, 1, 'C'], [1, 2, 0, 0, 0, 'A'], [2, 1, 1, 1, 0, 'A'], [1, 3, 2, 1, 0, 'A'], [1, 3, 0, 0, 0, 'C'],
                 [2, 0, 0, 0, 0, 'A'], [1, 0, 2, 0, 1, 'A'], [1, 1, 0, 0, 0, 'C'], [1, 1, 2, 0, 1, 'A'], [0, 2, 1, 1, 0, 'C'],
                 [1, 1, 1, 0, 1, 'C'], [0, 3, 0, 0, 0, 'A'], [2, 0, 2, 0, 0, 'A'], [1, 1, 2, 0, 1, 'C'], [0, 0, 1, 1, 0, 'C'],
                 [0, 2, 0, 0, 0, 'A'], [1, 1, 2, 1, 1, 'C'], [1, 3, 0, 0, 0, 'A'], [1, 2, 1, 0, 0, 'C'], [1, 2, 2, 0, 0, 'A'],
                 [2, 1, 1, 0, 0, 'A'], [0, 2, 0, 0, 0, 'A'], [1, 0, 0, 0, 0, 'C'], [0, 2, 0, 0, 0, 'C'], [0, 1, 2, 0, 1, 'A'],
                 [0, 3, 0, 1, 1, 'A'], [1, 3, 2, 0, 0, 'A'], [1, 1, 1, 1, 1, 'A'], [0, 2, 0, 1, 0, 'A'], [0, 3, 2, 1, 1, 'A'],
                 [1, 1, 2, 0, 1, 'C'], [1, 2, 1, 1, 0, 'C'], [2, 1, 0, 0, 0, 'A'], [1, 2, 1, 1, 0, 'C'], [1, 2, 0, 0, 0, 'C'],
                 [1, 3, 2, 0, 1, 'C'], [2, 3, 0, 0, 0, 'C'], [0, 3, 1, 0, 0, 'A'], [1, 3, 2, 1, 0, 'A'], [0, 2, 2, 0, 0, 'A'],
                 [1, 3, 0, 0, 0, 'C'], [1, 0, 0, 0, 0, 'A'], [1, 2, 1, 1, 0, 'A'], [1, 3, 1, 0, 0, 'A'], [0, 3, 2, 0, 1, 'C'],
                 [2, 0, 0, 0, 0, 'C'], [1, 1, 1, 1, 0, 'A'], [0, 3, 2, 1, 1, 'C'], [0, 2, 2, 0, 0, 'A'], [2, 0, 1, 1, 0, 'C'],
                 [2, 0, 1, 1, 0, 'C'], [2, 2, 1, 1, 0, 'C'], [1, 2, 0, 0, 0, 'C'], [2, 1, 0, 0, 0, 'A'], [2, 0, 1, 1, 0, 'A'],
                 [0, 3, 0, 1, 0, 'A'], [0, 3, 1, 1, 0, 'C'], [1, 3, 1, 0, 0, 'A'], [2, 2, 0, 0, 0, 'A'], [0, 3, 0, 0, 0, 'A'],
                 [1, 3, 2, 0, 1, 'C'], [2, 1, 1, 0, 0, 'C'], [2, 2, 2, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [0, 2, 1, 0, 0, 'C'],
                 [2, 2, 0, 0, 0, 'C'], [0, 3, 2, 0, 1, 'C'], [1, 2, 0, 0, 0, 'C'], [1, 3, 1, 1, 1, 'C'], [2, 1, 2, 0, 0, 'C'],
                 [2, 1, 0, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [0, 1, 0, 0, 1, 'C'], [1, 2, 2, 0, 0, 'C'], [0, 1, 2, 0, 1, 'C'],
                 [0, 2, 1, 1, 1, 'C'], [0, 2, 1, 0, 0, 'C'], [0, 3, 2, 1, 0, 'C'], [2, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'],
                 [1, 3, 2, 0, 1, 'C'], [2, 1, 0, 0, 0, 'C'], [1, 1, 0, 0, 0, 'C'], [2, 1, 1, 0, 0, 'C'], [0, 2, 1, 1, 0, 'C'],
                 [1, 1, 0, 0, 0, 'C'], [2, 1, 2, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [2, 2, 1, 0, 0, 'C'], [1, 3, 0, 0, 0, 'C']]
        # Hand-picked test set.
        test = [[1, 2, 1, 1, 0, 'C'], [1, 3, 2, 0, 0, 'C'], [2, 0, 1, 1, 0, 'C'], [1, 2, 0, 0, 0, 'A'], [2, 3, 0, 0, 0, 'A'],
                [0, 2, 2, 0, 0, 'A'], [1, 3, 0, 0, 0, 'A'], [2, 1, 2, 0, 1, 'C'], [1, 1, 0, 0, 0, 'C'], [1, 0, 2, 0, 1, 'A'],
                [2, 1, 0, 0, 0, 'C'], [2, 2, 2, 0, 0, 'A'], [2, 2, 1, 0, 0, 'C'], [0, 3, 2, 1, 1, 'C'], [0, 3, 0, 0, 0, 'A'],
                [2, 0, 2, 0, 0, 'A'], [2, 0, 0, 0, 0, 'A'], [1, 1, 2, 0, 1, 'A'], [2, 1, 0, 0, 0, 'C'], [0, 0, 1, 1, 0, 'C'],
                [1, 2, 1, 0, 0, 'C'], [0, 2, 2, 0, 0, 'A'], [2, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [2, 1, 0, 0, 0, 'A'],
                [1, 1, 1, 0, 1, 'C'], [2, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'A'], [1, 1, 0, 0, 0, 'A'], [1, 1, 2, 1, 1, 'C'],
                [1, 1, 2, 0, 1, 'C'], [1, 2, 2, 0, 0, 'A'], [0, 2, 0, 0, 0, 'C'], [1, 2, 0, 0, 0, 'C'], [1, 0, 0, 0, 0, 'C'],
                [0, 2, 1, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [2, 1, 0, 0, 0, 'C'], [0, 2, 1, 1, 0, 'A']]
        labels = ['age','profession','born','location','insurance']
        return dataset,labels,test

    def split_trian_test(self, dataset, test_size):
        """Split ``dataset`` with scikit-learn's ``train_test_split``.

        :param dataset: list of samples.
        :param test_size: forwarded unchanged to ``train_test_split``.
        :return: ``(train_data, test_data)``.
        """
        train_data, test_data = train_test_split(dataset, test_size=test_size)
        return train_data, test_data

    def _iter_internal_nodes(self, node, rows, labels, path=()):
        """Yield ``(path, node, rows)`` for every internal node below ``node``.

        ``path`` is the tuple of dict keys leading from the root to the
        node; ``rows`` are the samples that are routed to the node.
        """
        if not self._is_tree(node):
            return
        yield path, node, rows
        feature = next(iter(node))
        index = labels.index(feature)
        branches = node[feature]
        if self._uses_binary_tree():
            value = next(iter(branches))
            matched = [row for row in rows if row[index] == value]
            unmatched = [row for row in rows if row[index] != value]
            for side, subset in (('is', matched), ('not is', unmatched)):
                for item in self._iter_internal_nodes(
                        branches[value][side], subset, labels,
                        path + (feature, value, side)):
                    yield item
        else:
            for value, child in branches.items():
                subset = [row for row in rows if row[index] == value]
                for item in self._iter_internal_nodes(
                        child, subset, labels, path + (feature, value)):
                    yield item

    def _replace_subtree(self, tree, path, replacement):
        """Return a copy of ``tree`` with the node at ``path`` replaced.

        Only the dicts along ``path`` are copied; ``tree`` itself is left
        untouched, so earlier snapshots of the tree stay valid.
        """
        if not path:
            return replacement
        rebuilt = dict(tree)
        rebuilt[path[0]] = self._replace_subtree(tree[path[0]], path[1:], replacement)
        return rebuilt

    def pruning_ccp_for_mulit_fork(self, input_tree, data, data_labels):
        """Cost-complexity pruning (CCP); returns the best tree of the sequence.

        In every round each internal node t gets the score
        ``alpha = (R(t) - R(T_t) + s) / (leaves(T_t) - 1 + s)`` where R(t) is
        the error cost if t became a leaf, R(T_t) the error cost of the
        subtree rooted at t (both measured on the samples of ``data`` that
        reach t, divided by ``len(data)``) and ``s = 0.01`` is a smoothing
        term that also prevents a zero denominator. The node with the
        smallest alpha is replaced by its most frequent leaf label. At most
        ``config['pruning_num']`` rounds are run. Both multi-way and binary
        (CART) trees are handled, chosen by ``config['tree_fork']``.

        ``input_tree`` is not modified. ``self.min_alpha`` and
        ``self.min_tree`` hold the node chosen in the last round, as in the
        original implementation.

        :param input_tree: the tree to prune.
        :param data: samples used to measure the errors, label last.
        :param data_labels: feature names, in the column order of ``data``.
        :return: the tree of the pruning sequence (unpruned tree included)
            with the lowest error on ``data``; the earliest one on ties.
        """
        num_data = len(data)
        if num_data == 0:
            return input_tree  # nothing to measure the error against
        smoothing = 0.01
        current = input_tree
        candidates = [(self.clac_err(current, data, data_labels), current)]
        for _ in range(self.config['pruning_num']):
            if not self._is_tree(current):
                break  # already pruned down to a single leaf
            self.min_alpha = float('inf')
            self.min_tree = current
            best_path = ()
            best_label = None
            for path, node, rows in self._iter_internal_nodes(current, data, data_labels):
                leaf_label = self._leaf_max_label(node)
                wrong_as_leaf = sum(1 for row in rows if row[-1] != leaf_label)
                node_cost = float(wrong_as_leaf) / num_data
                subtree_cost = self._get_error_value(node, rows, data_labels, total=num_data)
                leaf_num = self._get_leaf_num(node)
                alpha = (node_cost - subtree_cost + smoothing) / (leaf_num - 1 + smoothing)
                if alpha < self.min_alpha:
                    self.min_alpha = alpha
                    self.min_tree = node
                    best_path = path
                    best_label = leaf_label
            current = self._replace_subtree(current, best_path, best_label)
            candidates.append((self.clac_err(current, data, data_labels), current))
        # min() keeps the first minimum, i.e. the least pruned tree on ties.
        return min(candidates, key=lambda item: item[0])[1]

    def clac_err(self, tree, data, labels):
        """Return the error rate of ``tree`` on ``data``.

        :param tree: tree (or leaf label) used for classification.
        :param data: samples to classify, true label in the last column.
        :param labels: feature names, in the column order of ``data``.
        :return: misclassified fraction in [0, 1]; 0.0 for empty ``data``.
        """
        return self._get_error_value(tree, data, labels)

    def create_tree(self, data, lables):
        """Build a tree of the kind selected by ``config['tree_fork']``.

        :param data: non-empty list of training samples, label last.
        :param lables: feature names (parameter name kept, typo included,
            so existing keyword calls keep working).
        :return: a CART tree when ``tree_fork`` is 2, otherwise a multi-way
            C4.5 tree.
        """
        if self._uses_binary_tree():
            self.feature_value = []
            return self._create_two_fork_tree(data, lables)
        return self._create_mulit_fork_tree(data, lables)

if __name__ == '__main__':
    # Demo: build a tree on the bundled data, then prune it.
    config = {
        # Number of branches per node: 2 builds a binary CART tree; any
        # other value builds the multi-way C4.5 tree (the default).
        'tree_fork': 2,
        # Post-pruning type ("CPP" in the original note); it is not read
        # by the code visible in this file.
        'prun_type': '',
        # Maximum number of pruning rounds.
        'pruning_num': 20,
    }
    # Load the samples and the feature names. The hand-picked test set
    # returned here is replaced by the random split two lines below.
    dataset, labels, test_data = DecisionTree.load_data()
    decision_tree = DecisionTree(config)
    train_data, test_data = decision_tree.split_trian_test(dataset, 0.3)
    tree = decision_tree.create_tree(train_data, labels)
    # decision_tree.show_tree(tree)
    # Tree left after pruning against the held-out samples.
    finally_tree = decision_tree.pruning_ccp_for_mulit_fork(tree, test_data, labels)

