決策樹算法通常用來解決有監督的分類問題,本章主要講解決策樹中的 ID3 算法。
1.工作原理
決策樹根據特徵對數據集進行劃分和分類,所以關鍵在於如何選擇特徵。這裏就用到了信息論的知識。在信息論與概率統計中,熵表示隨機變量不確定性的度量。熵越大,隨機變量的不確定性就越大。即在未分類之前,數據集是無序的,熵是最大的。而通過分類,可以使數據集變得更加有序,熵減小。原始數據集 的熵 是這樣計算的:
表示數據集中樣本的總個數, 有 個分類 ,, 是屬於類 的樣本個數,那麼。
然後開始選取特徵,數據集根據選取特徵的取值劃分成若干個子數據集,計算劃分之後的熵。
假設選取特徵 , 有 個不同的取值 ,根據特徵 的取值將 劃分爲 個子數據集 , 爲 的樣本個數,。記子集 中屬於類 的樣本的集合爲 ,即 , 爲 的樣本個數,那麼劃分之後的熵 爲:
用原始數據集的熵減去劃分之後的熵就得到了信息增益 :
我們總是優先選取使得數據集信息增益最大的特徵。對於得到的每個子數據集,需要先把選取過的特徵去除掉,然後再重複上面的操作,即根據餘下的特徵繼續劃分子數據集,直到子數據集中所有的樣本都屬於同一個分類,或者用完了所有的特徵,這種情況下通常採用多數表決的方法,即該子數據集所屬的分類的類標籤爲該子數據集中類標籤數目最多的類標籤。因此,可以採用遞歸的方式建立決策樹。
2.優缺點及適用範圍
優點:計算複雜度不高,輸出結果易於理解,對中間值的缺失不敏感,可以處理不相關特徵數據。
缺點:可能會產生過度匹配的問題。
適用數據類型:數值型和標稱型。
3.代碼實現
本文實現了書中的代碼,並從中發現了一些小問題,這裏羅列一下:
-
字典的 keys() 方法返回值類型爲 dict_keys,其並不能取索引,因此需要將其轉成 list;
-
建樹時,傳入的參數 labels 會刪去第一個選取的特徵,即傳入的 labels 會發生改變,因此最好傳入 labels 的淺拷貝;
-
在存儲分類器時會報錯,所以採用二進制進行讀寫;
-
在使用 matplotlib 繪圖時,不能顯示中文,原中文字符變成小方格,只需在文件中加入兩行代碼即可:
from pylab import * mpl.rcParams['font.sans-serif'] = ['SimHei']
trees.py 文件(main方法的語句塊最好單獨執行):
from math import log
import operator
import pickle
import treePlotter
def calcShannonEnt(dataSet):
'''計算數據集的熵'''
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob,2)
return shannonEnt
def createDataSet():
'''創建數據集'''
dataSet = [[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet,labels
def splitDataSet(dataSet,axis,value):
'''按照給定特徵劃分數據集'''
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
'''選擇最好的數據集劃分方式'''
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
# 創建唯一的分類標籤列表
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
# 計算每種劃分方式的信息熵
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
# 計算最好的信息增益
infoGain = baseEntropy - newEntropy
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
def majorityCnt(classList):
'''計算出現最多的類標籤'''
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),
reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet,labels):
'''建樹'''
classList = [example[-1] for example in dataSet]
# 類別完全相同則停止繼續劃分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 遍歷完所有特徵時返回出現次數最多的類標籤
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet\
(dataSet,bestFeat,value),subLabels)
return myTree
def classify(inputTree,featLabels,testVec):
'''使用決策樹的分類函數'''
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key],featLabels,testVec)
else:
classLabel = secondDict[key]
return classLabel
def storeTree(inputTree,filename):
'''存儲決策樹'''
fw = open(filename,'wb')
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
'''加載決策樹'''
fr = open(filename,'rb')
return pickle.load(fr)
if __name__ == '__main__':
# 1.計算信息熵
# myDat,labels = createDataSet()
# print(myDat)
# print(calcShannonEnt(myDat))
# myDat[0][-1] = 'maybe'
# print(myDat)
# print(calcShannonEnt(myDat))
# 2.按屬性值分割數據集
# myDat,labels = createDataSet()
# print(splitDataSet(myDat,0,1))
# print(splitDataSet(myDat,0,0))
# 3.選擇使得信息增益最大的特徵
# myDat,labels = createDataSet()
# print(chooseBestFeatureToSplit(myDat))
# 4.創建決策樹
# myDat,labels = createDataSet()
# myTree = createTree(myDat,labels)
# print(myTree)
# 5.使用決策樹進行分類
# myDat,labels = createDataSet()
# myTree = createTree(myDat,labels[:])
# print(myTree)
# print(classify(myTree,labels,[1,0]))
# print(classify(myTree,labels,[1,1]))
# 6.存儲和加載決策樹
# myDat,labels = createDataSet()
# myTree = createTree(myDat,labels[:])
# storeTree(myTree,'classifierStorage.txt')
# print(grabTree('classifierStorage.txt'))
# 7.使用決策樹預測隱形眼鏡類型
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age','prescript','astigmatic','tearRate']
lensesTree = createTree(lenses,lensesLabels[:])
print(lensesTree)
treePlotter.createPlot(lensesTree)
treePlotter.py 文件(main方法的語句塊最好單獨執行):
'''使用文本註解繪製樹節點'''
import matplotlib.pyplot as plt
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']
# 定義文本框和箭頭格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
'''繪製帶箭頭的註解'''
createPlot.ax1.annotate(nodeTxt,xy=parentPt,
xycoords='axes fraction',
xytext=centerPt,textcoords='axes fraction',
va="center",ha="center",
bbox=nodeType,arrowprops=arrow_args )
def getNumLeafs(myTree):
'''獲取葉節點的數目'''
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
'''獲取樹的層數'''
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]
def plotMidText(cntrPt,parentPt,txtString):
'''在父子節點間填充文本信息'''
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString,va="center",ha="center",rotation=30)
def plotTree(myTree,parentPt,nodeTxt):
'''計算寬與高'''
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff
+ (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDic = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDic.keys():
if type(secondDic[key]).__name__ == 'dict':
plotTree(secondDic[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDic[key],(plotTree.xOff,plotTree.yOff),
cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1,facecolor='white')
fig.clf()
axprops = dict(xticks=[],yticks=[])
createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()
# def createPlot():
# fig = plt.figure(1,facecolor='white')
# fig.clf()
# createPlot.ax1 = plt.subplot(111,frameon=False)
# plotNode(U'決策節點',(0.5, 0.1),(0.1, 0.5),decisionNode)
# plotNode(U'葉節點',(0.8, 0.1),(0.3, 0.8),leafNode)
# plt.show()
if __name__ == '__main__':
# 1.繪製樹節點,這裏使用註釋掉的函數
# createPlot()
# 2.測試預定義的樹結構,葉子節點和樹層數函數
# print(retrieveTree(1))
# myTree = retrieveTree(0)
# print(getNumLeafs(myTree))
# print(getTreeDepth(myTree))
# 3.繪製決策樹
myTree = retrieveTree(0)
createPlot(myTree)
myTree['no surfacing'][3] = 'maybe'
print(myTree)
createPlot(myTree)
4.相關文件
這裏給出本文用到的相關文件。
鏈接: https://pan.baidu.com/s/1SRHqvRF8Q0iZjZs_tpm35Q 提取碼: 9c4g
參考文獻
- 機器學習實戰
- 統計學習方法