【原創】決策樹python源碼實現(含預剪枝和後剪枝)

決策樹python源碼實現(含預剪枝和後剪枝)

一、說明

        所用的環境爲Ubuntu + python 3.6,在jupyter中運行。本文實現周志華《機器學習》西瓜書中的4.1 ~ 4.3中的決策樹算法(不含連續值、缺失值處理),對應李航《統計學習方法》的5.1 ~ 5.4節。畫圖工具參考《機器學習實戰》中的部分代碼,本文樹的生成代碼大部分由自己完成,部分思路可能與該書有差異。另外本文程序的效率應該比《機器學習實戰》的高,因爲該書上有很多逐個樣本的遍歷,在本文中則使用numpy直接操作向量實現。
         由於個人理解問題,並且未經過大量數據的測試,算法實現可能會存在問題或值得改進的地方,請遇到的朋友幫忙指出,共同學習!
         代碼中有足夠的註釋!

二、代碼實現

  1. 開始
import math
import numpy as np 
  1. 創建數據
# 創建數據集 備註 李航《統計學習方法》中表5.1 貸款申請數據數據
def createDataLH():
    data = np.array([['青年', '否', '否', '一般']])
    data = np.append(data, [['青年', '否', '否', '好']], axis = 0)
    data = np.append(data, [['青年', '是', '否', '好'] 
                            , ['青年', '是', '是', '一般']
                            , ['青年', '否', '否', '一般']
                            , ['中年', '否', '否', '一般']
                            , ['中年', '否', '否', '好']
                            , ['中年', '是', '是', '好']
                            , ['中年', '否', '是', '非常好']
                            , ['中年', '否', '是', '非常好']
                            , ['老年', '否', '是', '非常好']
                            , ['老年', '否', '是', '好']
                            , ['老年', '是', '否', '好']
                            , ['老年', '是', '否', '非常好']
                            , ['老年', '否', '否', '一般']
                           ], axis = 0)
    label = np.array(['否', '否', '是', '是', '否', '否', '否', '是', '是', '是', '是', '是', '是', '是', '否'])
    name = np.array(['年齡', '有工作', '有房子', '信貸情況'])
    return data, label, name

# 創建西瓜書數據集2.0
def createDataXG20():
    data = np.array([['青綠', '蜷縮', '濁響', '清晰', '凹陷', '硬滑']
                    , ['烏黑', '蜷縮', '沉悶', '清晰', '凹陷', '硬滑']
                    , ['烏黑', '蜷縮', '濁響', '清晰', '凹陷', '硬滑']
                    , ['青綠', '蜷縮', '沉悶', '清晰', '凹陷', '硬滑']
                    , ['淺白', '蜷縮', '濁響', '清晰', '凹陷', '硬滑']
                    , ['青綠', '稍蜷', '濁響', '清晰', '稍凹', '軟粘']
                    , ['烏黑', '稍蜷', '濁響', '稍糊', '稍凹', '軟粘']
                    , ['烏黑', '稍蜷', '濁響', '清晰', '稍凹', '硬滑']
                    , ['烏黑', '稍蜷', '沉悶', '稍糊', '稍凹', '硬滑']
                    , ['青綠', '硬挺', '清脆', '清晰', '平坦', '軟粘']
                    , ['淺白', '硬挺', '清脆', '模糊', '平坦', '硬滑']
                    , ['淺白', '蜷縮', '濁響', '模糊', '平坦', '軟粘']
                    , ['青綠', '稍蜷', '濁響', '稍糊', '凹陷', '硬滑']
                    , ['淺白', '稍蜷', '沉悶', '稍糊', '凹陷', '硬滑']
                    , ['烏黑', '稍蜷', '濁響', '清晰', '稍凹', '軟粘']
                    , ['淺白', '蜷縮', '濁響', '模糊', '平坦', '硬滑']
                    , ['青綠', '蜷縮', '沉悶', '稍糊', '稍凹', '硬滑']])
    label = np.array(['是', '是', '是', '是', '是', '是', '是', '是', '否', '否', '否', '否', '否', '否', '否', '否', '否'])
    name = np.array(['色澤', '根蒂', '敲聲', '紋理', '臍部', '觸感'])
    return data, label, name

def splitXgData20(xgData, xgLabel):
    xgDataTrain = xgData[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16],:]
    xgDataTest = xgData[[3, 4, 7, 8, 10, 11, 12],:]
    xgLabelTrain = xgLabel[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16]]
    xgLabelTest = xgLabel[[3, 4, 7, 8, 10, 11, 12]]
    return xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest

  1. 創建基礎函數
    如計算熵、計算條件熵、信息增益、信息增益率等
# 定義一個常用函數 用來求numpy array中數值等於某值的元素數量
equalNums = lambda x,y: 0 if x is None else x[x==y].size


# 定義計算信息熵的函數
def singleEntropy(x):
    """計算一個輸入序列的信息熵"""
    # 轉換爲 numpy 矩陣
    x = np.asarray(x)
    # 取所有不同值
    xValues = set(x)
    # 計算熵值
    entropy = 0
    for xValue in xValues:
        p = equalNums(x, xValue) / x.size 
        entropy -= p * math.log(p, 2)
    return entropy
    
    
# 定義計算條件信息熵的函數
def conditionnalEntropy(feature, y):
    """計算 某特徵feature 條件下y的信息熵"""
    # 轉換爲numpy 
    feature = np.asarray(feature)
    y = np.asarray(y)
    # 取特徵的不同值
    featureValues = set(feature)
    # 計算熵值 
    entropy = 0
    for feat in featureValues:
        # 解釋:feature == feat 是得到取feature中所有元素值等於feat的元素的索引(類似這樣理解)
        #       y[feature == feat] 是取y中 feature元素值等於feat的元素索引的 y的元素的子集
        p = equalNums(feature, feat) / feature.size 
        entropy += p * singleEntropy(y[feature == feat])
    return entropy
    
    
# 定義信息增益
def infoGain(feature, y):
    return singleEntropy(y) - conditionnalEntropy(feature, y)


# 定義信息增益率
def infoGainRatio(feature, y):
    return 0 if singleEntropy(feature) == 0 else infoGain(feature, y) / singleEntropy(feature)

函數功能測試
使用李航數據測試函數 p62

# 使用李航數據測試函數 p62
lhData, lhLabel, lhName = createDataLH()
print("書中H(D)爲0.971,函數結果:" + str(round(singleEntropy(lhLabel), 3)))  
print("書中g(D, A1)爲0.083,函數結果:" + str(round(infoGain(lhData[:,0] ,lhLabel), 3)))  
print("書中g(D, A2)爲0.324,函數結果:" + str(round(infoGain(lhData[:,1] ,lhLabel), 3)))  
print("書中g(D, A3)爲0.420,函數結果:" + str(round(infoGain(lhData[:,2] ,lhLabel), 3)))  
print("書中g(D, A4)爲0.363,函數結果:" + str(round(infoGain(lhData[:,3] ,lhLabel), 3)))  
# 測試正常,與書中結果一致

運行結果

書中H(D)爲0.971,函數結果:0.971
書中g(D, A1)爲0.083,函數結果:0.083
書中g(D, A2)爲0.324,函數結果:0.324
書中g(D, A3)爲0.420,函數結果:0.42
書中g(D, A4)爲0.363,函數結果:0.363

使用西瓜數據測試函數 p75-p77

# 使用西瓜數據測試函數  p75-p77
xgData, xgLabel, xgName = createDataXG20()
print("書中Ent(D)爲0.998,函數結果:" + str(round(singleEntropy(xgLabel), 4)))  
print("書中Gain(D, 色澤)爲0.109,函數結果:" + str(round(infoGain(xgData[:,0] ,xgLabel), 4)))  
print("書中Gain(D, 根蒂)爲0.143,函數結果:" + str(round(infoGain(xgData[:,1] ,xgLabel), 4)))  
print("書中Gain(D, 敲聲)爲0.141,函數結果:" + str(round(infoGain(xgData[:,2] ,xgLabel), 4)))  
print("書中Gain(D, 紋理)爲0.381,函數結果:" + str(round(infoGain(xgData[:,3] ,xgLabel), 4)))  
print("書中Gain(D, 臍部)爲0.289,函數結果:" + str(round(infoGain(xgData[:,4] ,xgLabel), 4)))  
print("書中Gain(D, 觸感)爲0.006,函數結果:" + str(round(infoGain(xgData[:,5] ,xgLabel), 4)))  

運行結果

書中Ent(D)爲0.998,函數結果:0.9975
書中Gain(D, 色澤)爲0.109,函數結果:0.1081
書中Gain(D, 根蒂)爲0.143,函數結果:0.1427
書中Gain(D, 敲聲)爲0.141,函數結果:0.1408
書中Gain(D, 紋理)爲0.381,函數結果:0.3806
書中Gain(D, 臍部)爲0.289,函數結果:0.2892
書中Gain(D, 觸感)爲0.006,函數結果:0.006

測試結果與書中數據基本一致

  1. 創建樹生成相關函數
    如 特徵選取、數據分割、多數投票、樹生成、使用樹分類、樹信息統計
# 特徵選取
def bestFeature(data, labels, method = 'id3'):
    assert method in ['id3', 'c45'], "method 須爲id3或c45"
    data = np.asarray(data)
    labels = np.asarray(labels)
    # 根據輸入的method選取 評估特徵的方法:id3 -> 信息增益; c45 -> 信息增益率
    def calcEnt(feature, labels):
        if method == 'id3':
            return infoGain(feature, labels)
        elif method == 'c45' :
            return infoGainRatio(feature, labels)
    # 特徵數量  即 data 的列數量
    featureNum = data.shape[1]
    # 計算最佳特徵
    bestEnt = 0 
    bestFeat = -1
    for feature in range(featureNum):
        ent = calcEnt(data[:, feature], labels)
        if ent >= bestEnt:
            bestEnt = ent 
            bestFeat = feature
        # print("feature " + str(feature + 1) + " ent: " + str(ent)+ "\t bestEnt: " + str(bestEnt))
    return bestFeat, bestEnt 


# 根據特徵及特徵值分割原數據集  刪除data中的feature列,並根據feature列中的值分割 data和label
def splitFeatureData(data, labels, feature):
    """feature 爲特徵列的索引"""
    # 取特徵列
    features = np.asarray(data)[:,feature]
    # 數據集中刪除特徵列
    data = np.delete(np.asarray(data), feature, axis = 1)
    # 標籤
    labels = np.asarray(labels)
    
    uniqFeatures = set(features)
    dataSet = {}
    labelSet = {}
    for feat in uniqFeatures:
        dataSet[feat] = data[features == feat]
        labelSet[feat] = labels[features == feat]
    return dataSet, labelSet
    
    
# 多數投票 
def voteLabel(labels):
    uniqLabels = list(set(labels))
    labels = np.asarray(labels)

    finalLabel = 0
    labelNum = []
    for label in uniqLabels:
        # 統計每個標籤值得數量
        labelNum.append(equalNums(labels, label))
    # 返回數量最大的標籤
    return uniqLabels[labelNum.index(max(labelNum))]


# 創建決策樹
def createTree(data, labels, names, method = 'id3'):
    data = np.asarray(data)
    labels = np.asarray(labels)
    names = np.asarray(names)
    # 如果結果爲單一結果
    if len(set(labels)) == 1: 
        return labels[0] 
    # 如果沒有待分類特徵
    elif data.size == 0: 
        return voteLabel(labels)
    # 其他情況則選取特徵 
    bestFeat, bestEnt = bestFeature(data, labels, method = method)
    # 取特徵名稱
    bestFeatName = names[bestFeat]
    # 從特徵名稱列表刪除已取得特徵名稱
    names = np.delete(names, [bestFeat])
    # 根據選取的特徵名稱創建樹節點
    decisionTree = {bestFeatName: {}}
    # 根據最優特徵進行分割
    dataSet, labelSet = splitFeatureData(data, labels, bestFeat)
    # 對最優特徵的每個特徵值所分的數據子集進行計算
    for featValue in dataSet.keys():
        decisionTree[bestFeatName][featValue] = createTree(dataSet.get(featValue), labelSet.get(featValue), names, method)
    return decisionTree 


# 樹信息統計 葉子節點數量 和 樹深度
def getTreeSize(decisionTree):
    nodeName = list(decisionTree.keys())[0]
    nodeValue = decisionTree[nodeName]
    leafNum = 0
    treeDepth = 0 
    leafDepth = 0
    for val in nodeValue.keys():
        if type(nodeValue[val]) == dict:
            leafNum += getTreeSize(nodeValue[val])[0]
            leafDepth = 1 + getTreeSize(nodeValue[val])[1] 
        else :
            leafNum += 1 
            leafDepth = 1 
        treeDepth = max(treeDepth, leafDepth)
    return leafNum, treeDepth 


# 使用模型對其他數據分類
def dtClassify(decisionTree, rowData, names):
    names = list(names)
    # 獲取特徵
    feature = list(decisionTree.keys())[0]
    # 決策樹對於該特徵的值的判斷字段
    featDict = decisionTree[feature]
    # 獲取特徵的列
    feat = names.index(feature)
    # 獲取數據該特徵的值
    featVal = rowData[feat]
    # 根據特徵值查找結果,如果結果是字典說明是子樹,調用本函數遞歸
    if featVal in featDict.keys():
        if type(featDict[featVal]) == dict:
            classLabel = dtClassify(featDict[featVal], rowData, names)
        else:
            classLabel = featDict[featVal] 
    return classLabel
  1. 樹可視化
    畫圖的方法主要參考《機器學習實戰》,細節有較多改動。畫圖時須提前配置好支持中文的畫圖。
# 可視化 主要源自《機器學習實戰》
import matplotlib.pyplot as plt 

decisionNodeStyle = dict(boxstyle = "sawtooth", fc = "0.8")
leafNodeStyle = {"boxstyle": "round4", "fc": "0.8"}
arrowArgs = {"arrowstyle": "<-"}


# 畫節點
def plotNode(nodeText, centerPt, parentPt, nodeStyle):
    createPlot.ax1.annotate(nodeText, xy = parentPt, xycoords = "axes fraction", xytext = centerPt
                            , textcoords = "axes fraction", va = "center", ha="center", bbox = nodeStyle, arrowprops = arrowArgs)


# 添加箭頭上的標註文字
def plotMidText(centerPt, parentPt, lineText):
    xMid = (centerPt[0] + parentPt[0]) / 2.0
    yMid = (centerPt[1] + parentPt[1]) / 2.0 
    createPlot.ax1.text(xMid, yMid, lineText)
    
    
# 畫樹
def plotTree(decisionTree, parentPt, parentValue):
    # 計算寬與高
    leafNum, treeDepth = getTreeSize(decisionTree) 
    # 在 1 * 1 的範圍內畫圖,因此分母爲 1
    # 每個葉節點之間的偏移量
    plotTree.xOff = plotTree.figSize / (plotTree.totalLeaf - 1)
    # 每一層的高度偏移量
    plotTree.yOff = plotTree.figSize / plotTree.totalDepth
    # 節點名稱
    nodeName = list(decisionTree.keys())[0]
    # 根節點的起止點相同,可避免畫線;如果是中間節點,則從當前葉節點的位置開始,
    #      然後加上本次子樹的寬度的一半,則爲決策節點的橫向位置
    centerPt = (plotTree.x + (leafNum - 1) * plotTree.xOff / 2.0, plotTree.y)
    # 畫出該決策節點
    plotNode(nodeName, centerPt, parentPt, decisionNodeStyle)
    # 標記本節點對應父節點的屬性值
    plotMidText(centerPt, parentPt, parentValue)
    # 取本節點的屬性值
    treeValue = decisionTree[nodeName]
    # 下一層各節點的高度
    plotTree.y = plotTree.y - plotTree.yOff
    # 繪製下一層
    for val in treeValue.keys():
        # 如果屬性值對應的是字典,說明是子樹,進行遞歸調用; 否則則爲葉子節點
        if type(treeValue[val]) == dict:
            plotTree(treeValue[val], centerPt, str(val))
        else:
            plotNode(treeValue[val], (plotTree.x, plotTree.y), centerPt, leafNodeStyle)
            plotMidText((plotTree.x, plotTree.y), centerPt, str(val))
            # 移到下一個葉子節點
            plotTree.x = plotTree.x + plotTree.xOff
    # 遞歸完成後返回上一層
    plotTree.y = plotTree.y + plotTree.yOff
    
    
# 畫出決策樹
def createPlot(decisionTree):
    fig = plt.figure(1, facecolor = "white")
    fig.clf()
    axprops = {"xticks": [], "yticks": []}
    createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
    # 定義畫圖的圖形尺寸
    plotTree.figSize = 1.5 
    # 初始化樹的總大小
    plotTree.totalLeaf, plotTree.totalDepth = getTreeSize(decisionTree)
    # 葉子節點的初始位置x 和 根節點的初始層高度y
    plotTree.x = 0 
    plotTree.y = plotTree.figSize
    plotTree(decisionTree, (plotTree.figSize / 2.0, plotTree.y), "")
    plt.show()
  1. 使用示例數據進行測試
    使用李航的數據
# 使用李航數據測試函數 p62
lhData, lhLabel, lhName = createDataLH()
lhTree = createTree(lhData, lhLabel, lhName, method = 'id3')
print(lhTree)
createPlot(lhTree)

輸出如下
李航數據的結果

# 使用西瓜數據測試函數  p75-p77
xgData, xgLabel, xgName = createDataXG20()
xgTree = createTree(xgData, xgLabel, xgName, method = 'id3')
print(xgTree)
createPlot(xgTree)

輸出如下
西瓜數據2.0的結果備註:西瓜數據的結果與書上可能有差異,原因是西瓜數據在特徵選擇時,有多個特徵信息增益(率)是相同的,本文算法的選擇就容易和書上的選擇有出入。

  1. 預剪枝
    僅供參考
# 創建預剪枝決策樹
def createTreePrePruning(dataTrain, labelTrain, dataTest, labelTest, names, method = 'id3'):
    """
    預剪枝 需要使用測試數據對每次的劃分進行評估
         策略說明:原本如果某節點劃分前後的測試結果沒有提升,根據奧卡姆剃刀原則將不進行劃分(即執行剪枝),但考慮到這種策略容易造成欠擬合,
                   且不能排除後續劃分有進一步提升的可能,因此,沒有提升仍保留劃分,即不剪枝
         另外:周志華的書上評估的是某一個節點劃分前後對該層所有數據綜合評估,如評估對臍部 凹陷下色澤是否劃分,
               書上取的色澤劃分前的精度是71.4%(5/7),劃分後的精度是57.1%(4/7),都是臍部下三個特徵(凹陷,稍凹,平坦)所有的數據的精度,計算也不易
               而我覺得實際計算時,只對當前節點下的數據劃分前後進行評估即可,如臍部凹陷時有三個測試樣本,
               三個樣本色澤劃分前的精度是2/3=66.7%,色澤劃分後的精度是1/3=33.3%,因此判斷不劃分
    """
    trainData = np.asarray(dataTrain)
    labelTrain = np.asarray(labelTrain)
    testData = np.asarray(dataTest)
    labelTest = np.asarray(labelTest)
    names = np.asarray(names)
    # 如果結果爲單一結果
    if len(set(labelTrain)) == 1: 
        return labelTrain[0] 
    # 如果沒有待分類特徵
    elif trainData.size == 0: 
        return voteLabel(labelTrain)
    # 其他情況則選取特徵 
    bestFeat, bestEnt = bestFeature(dataTrain, labelTrain, method = method)
    # 取特徵名稱
    bestFeatName = names[bestFeat]
    # 從特徵名稱列表刪除已取得特徵名稱
    names = np.delete(names, [bestFeat])
    # 根據最優特徵進行分割
    dataTrainSet, labelTrainSet = splitFeatureData(dataTrain, labelTrain, bestFeat)

    # 預剪枝評估
    # 劃分前的分類標籤
    labelTrainLabelPre = voteLabel(labelTrain)
    labelTrainRatioPre = equalNums(labelTrain, labelTrainLabelPre) / labelTrain.size
    # 劃分後的精度計算 
    if dataTest is not None: 
        dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, bestFeat)
        # 劃分前的測試標籤正確比例
        labelTestRatioPre = equalNums(labelTest, labelTrainLabelPre) / labelTest.size
        # 劃分後 每個特徵值的分類標籤正確的數量
        labelTrainEqNumPost = 0
        for val in labelTrainSet.keys():
            labelTrainEqNumPost += equalNums(labelTestSet.get(val), voteLabel(labelTrainSet.get(val))) + 0.0
        # 劃分後 正確的比例
        labelTestRatioPost = labelTrainEqNumPost / labelTest.size 
    
    # 如果沒有評估數據 但劃分前的精度等於最小值0.5 則繼續劃分
    if dataTest is None and labelTrainRatioPre == 0.5:
        decisionTree = {bestFeatName: {}}
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue)
                                      , None, None, names, method)
    elif dataTest is None:
        return labelTrainLabelPre 
    # 如果劃分後的精度相比劃分前的精度下降, 則直接作爲葉子節點返回
    elif labelTestRatioPost < labelTestRatioPre:
        return labelTrainLabelPre
    else :
        # 根據選取的特徵名稱創建樹節點
        decisionTree = {bestFeatName: {}}
        # 對最優特徵的每個特徵值所分的數據子集進行計算
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue)
                                      , dataTestSet.get(featValue), labelTestSet.get(featValue)
                                      , names, method)
    return decisionTree 

預剪枝測試

# 將西瓜數據2.0分割爲測試集和訓練集
xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest = splitXgData20(xgData, xgLabel)
# 生成不剪枝的樹
xgTreeTrain = createTree(xgDataTrain, xgLabelTrain, xgName, method = 'id3')
# 生成預剪枝的樹
xgTreePrePruning = createTreePrePruning(xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest, xgName, method = 'id3')
# 畫剪枝前的樹
print("剪枝前的樹")
createPlot(xgTreeTrain)
# 畫剪枝後的樹
print("剪枝後的樹")
createPlot(xgTreePrePruning)

輸出結果
預剪枝結果對比備註,由於特徵選擇的問題,結果與書上有差異。

  1. 後剪枝

後剪枝評估時需要劃分前的標籤,這裏思考兩種方法:
        一是,不改變原來的訓練函數,評估時使用訓練數據對劃分前的節點標籤重新打標
        二是,改進訓練函數,在訓練的同時爲每個節點增加劃分前的標籤,這樣可以保證評估時只使用測試數據,避免再次使用大量的訓練數據
        這裏採用第二種方法 寫新的函數 createTreeWithLabel,當然也可以修改createTree來添加參數實現

另外,後剪枝的程序代碼中有很多過程中的提示信息,已註釋掉。

# 創建決策樹 帶預劃分標籤
def createTreeWithLabel(data, labels, names, method = 'id3'):
    data = np.asarray(data)
    labels = np.asarray(labels)
    names = np.asarray(names)
    # 如果不劃分的標籤爲
    votedLabel = voteLabel(labels)
    # 如果結果爲單一結果
    if len(set(labels)) == 1: 
        return votedLabel 
    # 如果沒有待分類特徵
    elif data.size == 0: 
        return votedLabel
    # 其他情況則選取特徵 
    bestFeat, bestEnt = bestFeature(data, labels, method = method)
    # 取特徵名稱
    bestFeatName = names[bestFeat]
    # 從特徵名稱列表刪除已取得特徵名稱
    names = np.delete(names, [bestFeat])
    # 根據選取的特徵名稱創建樹節點 劃分前的標籤votedPreDivisionLabel=_vpdl
    decisionTree = {bestFeatName: {"_vpdl": votedLabel}}
    # 根據最優特徵進行分割
    dataSet, labelSet = splitFeatureData(data, labels, bestFeat)
    # 對最優特徵的每個特徵值所分的數據子集進行計算
    for featValue in dataSet.keys():
        decisionTree[bestFeatName][featValue] = createTreeWithLabel(dataSet.get(featValue), labelSet.get(featValue), names, method)
    return decisionTree 


# 將帶預劃分標籤的tree轉化爲常規的tree
# 函數中進行的copy操作,原因見有道筆記 【YL20190621】關於Python中字典存儲修改的思考
def convertTree(labeledTree):
    labeledTreeNew = labeledTree.copy()
    nodeName = list(labeledTree.keys())[0]
    labeledTreeNew[nodeName] = labeledTree[nodeName].copy()
    for val in list(labeledTree[nodeName].keys()):
        if val == "_vpdl": 
            labeledTreeNew[nodeName].pop(val)
        elif type(labeledTree[nodeName][val]) == dict:
            labeledTreeNew[nodeName][val] = convertTree(labeledTree[nodeName][val])
    return labeledTreeNew


# 後剪枝 訓練完成後決策節點進行替換評估  這裏可以直接對xgTreeTrain進行操作
def treePostPruning(labeledTree, dataTest, labelTest, names):
    newTree = labeledTree.copy()
    dataTest = np.asarray(dataTest)
    labelTest = np.asarray(labelTest)
    names = np.asarray(names)
    # 取決策節點的名稱 即特徵的名稱
    featName = list(labeledTree.keys())[0]
    # print("\n當前節點:" + featName)
    # 取特徵的列
    featCol = np.argwhere(names==featName)[0][0]
    names = np.delete(names, [featCol])
    # print("當前節點劃分的數據維度:" + str(names))
    # print("當前節點劃分的數據:" )
    # print(dataTest)
    # print(labelTest)
    # 該特徵下所有值的字典
    newTree[featName] = labeledTree[featName].copy()
    featValueDict = newTree[featName]
    featPreLabel = featValueDict.pop("_vpdl")
    # print("當前節點預劃分標籤:" + featPreLabel)
    # 是否爲子樹的標記
    subTreeFlag = 0
    # 分割測試數據 如果有數據 則進行測試或遞歸調用  np的array我不知道怎麼判斷是否None, 用is None是錯的
    dataFlag = 1 if sum(dataTest.shape) > 0 else 0
    if dataFlag == 1:
        # print("當前節點有劃分數據!")
        dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, featCol)
    for featValue in featValueDict.keys():
        # print("當前節點屬性 {0} 的子節點:{1}".format(featValue ,str(featValueDict[featValue])))
        if dataFlag == 1 and type(featValueDict[featValue]) == dict:
            subTreeFlag = 1 
            # 如果是子樹則遞歸
            newTree[featName][featValue] = treePostPruning(featValueDict[featValue], dataTestSet.get(featValue), labelTestSet.get(featValue), names)
            # 如果遞歸後爲葉子 則後續進行評估
            if type(featValueDict[featValue]) != dict:
                subTreeFlag = 0 
            
        # 如果沒有數據  則轉換子樹
        if dataFlag == 0 and type(featValueDict[featValue]) == dict: 
            subTreeFlag = 1 
            # print("當前節點無劃分數據!直接轉換樹:"+str(featValueDict[featValue]))
            newTree[featName][featValue] = convertTree(featValueDict[featValue])
            # print("轉換結果:" + str(convertTree(featValueDict[featValue])))
    # 如果全爲葉子節點, 評估需要劃分前的標籤,這裏思考兩種方法,
    #     一是,不改變原來的訓練函數,評估時使用訓練數據對劃分前的節點標籤重新打標
    #     二是,改進訓練函數,在訓練的同時爲每個節點增加劃分前的標籤,這樣可以保證評估時只使用測試數據,避免再次使用大量的訓練數據
    #     這裏考慮第二種方法 寫新的函數 createTreeWithLabel,當然也可以修改createTree來添加參數實現
    if subTreeFlag == 0:
        ratioPreDivision = equalNums(labelTest, featPreLabel) / labelTest.size
        equalNum = 0
        for val in labelTestSet.keys():
            equalNum += equalNums(labelTestSet[val], featValueDict[val])
        ratioAfterDivision = equalNum / labelTest.size 
        # print("當前節點預劃分標籤的準確率:" + str(ratioPreDivision))
        # print("當前節點劃分後的準確率:" + str(ratioAfterDivision))
        # 如果劃分後的測試數據準確率低於劃分前的,則劃分無效,進行剪枝,即使節點等於預劃分標籤
        # 注意這裏取的是小於,如果有需要 也可以取 小於等於
        if ratioAfterDivision < ratioPreDivision:
            newTree = featPreLabel 
    return newTree

代碼測試,我對自己訓練的模型和書上生成的模型都進行了測試,這裏篇幅限制,且儘量保持與書中一致,僅提供書上模型的後剪枝

# 書中的樹結構 p81 p83
xgTreeBeforePostPruning = {"臍部": {"_vpdl": "是"
                                   , '凹陷': {'色澤':{"_vpdl": "是", '青綠': '是', '烏黑': '是', '淺白': '否'}}
                                   , '稍凹': {'根蒂':{"_vpdl": "是"
                                                  , '稍蜷': {'色澤': {"_vpdl": "是"
                                                                  , '青綠': '是'
                                                                  , '烏黑': {'紋理': {"_vpdl": "是"
                                                                               , '稍糊': '是', '清晰': '否', '模糊': '是'}}
                                                                  , '淺白': '是'}}
                                                  , '蜷縮': '否'
                                                  , '硬挺': '是'}}
                                   , '平坦': '否'}}
xgTreePostPruning = treePostPruning(xgTreeBeforePostPruning, xgDataTest, xgLabelTest, xgName)
createPlot(convertTree(xgTreeBeforePostPruning))
createPlot(xgTreePostPruning)

輸出結果
書中模型的後剪枝結論:對比書中p81和p83的圖4.5和圖4.7,與以上程序輸出一致!
圖p81
p81的圖
圖p83
圖p83

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章