由於近期學業繁重QAQ,所以我就不說廢話了,直接上代碼~
運行效果
代碼
from math import log
import operator
import matplotlib.pyplot as plt
#定義文本框和箭頭格式
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
#畫樹
#使用文本註解繪製樹節點
#繪製帶箭頭的註解
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
createPlot.ax1.annotate(nodeTxt,xy=parentPt,
xycoords='axes fraction',
xytext=centerPt,textcoords='axes fraction',
va="center",ha="center",bbox=nodeType,
arrowprops=arrow_args)
#在父子節點間填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
numLeafs=getNumLeafs(myTree)
depth=getTreeDepth(myTree)
firstStr=list(myTree.keys())[0]
cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,
plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict=myTree[firstStr]
plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),
cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),
cntrPt,str(key))
plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
fig=plt.figure(1,facecolor='white')
fig.clf()
axprops=dict(xticks=[],yticks=[])
createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
plotTree.totalW=float(getNumLeafs(inTree))
plotTree.totalD=float(getNumLeafs(inTree))
plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;
plotTree(inTree,(0.5,1.0),'')
plt.show()
#創建數據集
def createDataSet():
dataSet=[[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels=['no surfacing','flippers']
return dataSet,labels
#計算給定數據的香農熵
#熵值越高,混合的數據越多,越無序
#我們可以在數據集中添加更多的分類
def calcShannonEnt(dataSet):
numEntries=len(dataSet)
#數據字典,鍵值爲最後一列的數值"yes"or"no"
labelCounts={}
for featVec in dataSet:
#爲所有可能分類創建字典
#"yes"or"no"
currentLabel=featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel]=0
labelCounts[currentLabel]+=1
shannonEnt=0.0
for key in labelCounts:
prob=float(labelCounts[key])/numEntries
#以2爲㡳求對數
shannonEnt-=prob*log(prob,2)
return shannonEnt
#按照給定特徵劃分數據集
#輸入的參數爲:待劃分的數據集,
#劃分數據集的特徵(第幾列),
#特徵的返回值(這一列的值爲多少)
#返回的是符合這一列的值的每一行,
#並且將這一列的數據去掉了
def splitDataSet(dataSet,axis,value):
retDataSet=[]
#遍歷整個數據集
#featVec:[1, 1, 'yes']
for featVec in dataSet:
#print('featVec:')
#print(featVec)
#抽取其中符合特徵的
#featVec[axis]表示[1, 1, 'yes']中的第axis+1個
if featVec[axis]==value:
#保存這一列前面的數據
reducedFeatVec=featVec[:axis]
#print('reducedFeatVec:')
#print(reducedFeatVec)
#保存這一列後面的數據
reducedFeatVec.extend(featVec[axis+1:])
#print('reducedFeatVec:')
#print(reducedFeatVec)
retDataSet.append(reducedFeatVec)
#print('retDataSet:')
#print(retDataSet)
return retDataSet
#選擇最好的數據集劃分方式
def chooseBestFeatureToSplit(dataSet):
#numFeatures:2
numFeatures=len(dataSet[0])-1
#計算香農熵
baseEntropy=calcShannonEnt(dataSet)
bestInfoGain=0.0
bestFeature=-1
#i:0,1
for i in range(numFeatures):
#取出dataSet的第i列
featList=[example[i] for example in dataSet]
#print('featList:')
#print(featList)
#弄成一個set,去掉其中相同的元素
uniqueVals=set(featList)
#print('uniqueVals:')
#print(uniqueVals)
newEntropy=0.0
for value in uniqueVals:
#按照第i列,值爲value的去劃分
subDataSet=splitDataSet(dataSet,i,value)
prob=len(subDataSet)/float(len(dataSet))
#計算劃分後的熵值
newEntropy+=prob*calcShannonEnt(subDataSet)
infoGain=baseEntropy-newEntropy
#判斷是否更優
if(infoGain>bestInfoGain):
bestInfoGain=infoGain
bestFeature=i
#返回劃分的最優類別
#表示按照第i列去劃分
return bestFeature
#傳入的是分類名稱的列表
#返回出現次數最多的分類的名稱
def majorityCnt(classList):
#創建字典,鍵值爲classList中唯一值
#字典的值爲classList中每隔標籤出現的頻率
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote]=0
classCount[vote]+=1
#按照字典值的順序從大到小排序
sortedClassCount=sorted(classCount,iteritems(),
key=operator.itemgetter(1),reverse=True)
#返回出現次數最多的分類的名稱
return sortedClassCount[0][0]
#創建樹
#傳入參數爲數據集與標籤列表
def createTree(dataSet,labels):
#得到分類名稱的標籤"yes"or"no"
#['yes', 'yes', 'no', 'no', 'no']
classList=[example[-1] for example in dataSet]
#print('classList:')
#print(classList)
#遞歸結束的第一個條件
#所有的類標籤完全相同
if classList.count(classList[0])==len(classList):
return classList[0]
#遞歸結束的第二個條件
#使用完了所有的特徵,仍然不能將數
#據集劃分成僅包含唯一類別的分組
#此時無法簡單地返回唯一的類標籤,
#直接返回出現次數最多的類標籤
if len(dataSet[0])==1:
return majorityCnt(classList)
#bestFeat是最好的劃分方式對應的列的下標
bestFeat=chooseBestFeatureToSplit(dataSet)
#labels中這一列信息對應的類別名稱
bestFeatLabel=labels[bestFeat]
#樹
myTree={bestFeatLabel:{}}
#將labels中的這一類別delete
del(labels[bestFeat])
#這一類別對應的列的值
featValues=[example[bestFeat] for example in dataSet]
#print('featValues:')
#print(featValues)
#set 去掉列中相同的值
uniqueVals=set(featValues)
for value in uniqueVals:
#去掉最優類別後剩下的類別
subLabels=labels[:]
#print('subLabels:')
#print(subLabels)
#print('bestFeatLabel:')
#print(bestFeatLabel)
#print(value)
#myTree['no surfacing'][0]
#myTree['no surfacing'][1]
#......
myTree[bestFeatLabel][value]=createTree(
#按照第bestFeat列,值爲value的去劃分
splitDataSet(dataSet,bestFeat,value),subLabels)
return myTree
#獲取葉節點的數目
def getNumLeafs(myTree):
numLeafs=0
firstStr=list(myTree.keys())[0]
secondDir=myTree[firstStr]
for key in secondDir.keys():
#子節點爲字典類型,則該結點也是一個判斷結點
#需要遞歸調用getNumLeafs函數
if type(secondDir[key]).__name__=='dict':
numLeafs+=getNumLeafs(secondDir[key])
#該結點爲葉子節點,葉子數+1
else:
numLeafs+=1
return numLeafs
#獲取樹的層數
def getTreeDepth(myTree):
maxDepth=0
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth=1+getTreeDepth(secondDict[key])
else:
thisDepth=1
if thisDepth>maxDepth:maxDepth=thisDepth
return maxDepth
def main():
dataSet,labels=createDataSet()
chooseBestFeatureToSplit(dataSet)
#{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
myTree=createTree(dataSet,labels)
print('myTree:')
print(myTree)
createPlot(myTree)
#i=getNumLeafs(myTree)
#print(i)
#i=getTreeDepth(myTree)
#print(i)
#i=chooseBestFeatureToSplit(dataSet)
#print(i)
#shannonEnt=calcShannonEnt(dataSet)
#print(shannonEnt)
#增加一個類別後再測試信息熵,發現熵值增大
#dataSet[0][-1]='maybe'
#shannonEnt=calcShannonEnt(dataSet)
#print(shannonEnt)
#retDataSet=splitDataSet(dataSet,0,1)
#print('retDataSet:')
#print(retDataSet)
#retDataSet=splitDataSet(dataSet,0,0)
#print('retDataSet:')
#print(retDataSet)
if __name__=='__main__':
main()