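Bayesian classification is a family of algorithms that classify using probability and statistics, and it rests on Bayes' theorem. Bayes' theorem is stated as follows:
P(A|B) = P(B|A) * P(A) / P(B)
The formula shows that the posterior probability P(A|B) can be computed from the prior probability P(A), the conditional probability P(B|A), and the evidence P(B).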
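A naive Bayes classifier assumes that the features making up the evidence are conditionally independent of one another given the class. Under that assumption it computes the posterior probability of every class and assigns the sample to the class whose posterior is largest.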
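Written out, the independence assumption lets the class-conditional probability of a whole sample factor into per-feature terms. The notation here is an assumption made for illustration: B = (B_1, ..., B_n) is a sample with n feature values and A_i is one of the classes.

```latex
P(B \mid A_i) = \prod_{j=1}^{n} P(B_j \mid A_i)
\quad\Longrightarrow\quad
P(A_i \mid B) = \frac{P(A_i)\,\prod_{j=1}^{n} P(B_j \mid A_i)}{P(B)}
```

Because P(B) is the same for every class, comparing the numerators is enough to find the class with the largest posterior.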
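The steps for building a naive Bayes classifier are:
1. From the training samples, compute the prior probability P(Ai) of each class Ai.
2. For each feature attribute, compute the conditional probability P(Bj|Ai) of every value it takes within each class.
3. For each class, compute P(B|Ai) * P(Ai), where P(B|Ai) is the product of the P(Bj|Ai) over the feature values of B.
4. Pick the class Ak with the largest value from step 3 as the class of B.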
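The four steps can be sketched in a few lines of plain Python. This is a minimal illustration rather than the implementation used in this post: the function name `naive_bayes_scores` and the toy data are made up for the example, and probabilities are estimated from raw frequencies with no smoothing.

```python
from collections import Counter, defaultdict

def naive_bayes_scores(samples, labels, x):
    """Return {class: P(x|class) * P(class)} estimated from raw frequencies."""
    m = len(samples)
    class_count = Counter(labels)
    # Step 1: prior P(Ai) of each class
    prior = {c: class_count[c] / m for c in class_count}
    # Step 2: tally the counts behind P(Bj|Ai), keyed by (class, feature index)
    value_count = defaultdict(Counter)
    for row, c in zip(samples, labels):
        for j, v in enumerate(row):
            value_count[(c, j)][v] += 1
    # Step 3: P(B|Ai) * P(Ai) for each class
    scores = {}
    for c in class_count:
        score = prior[c]
        for j, v in enumerate(x):
            score *= value_count[(c, j)][v] / class_count[c]
        scores[c] = score
    return scores

# Hypothetical toy data: (outlook, windy) -> decision
samples = [("sunny", "no"), ("sunny", "yes"), ("rainy", "yes"), ("rainy", "no")]
labels = ["play", "stay", "stay", "play"]
scores = naive_bayes_scores(samples, labels, ("sunny", "no"))
# Step 4: the class with the largest score wins
best = max(scores, key=scores.get)
print(scores, best)
```

Note that a feature value never seen in a class drives that class's score to zero here, which is why real implementations usually add some form of smoothing.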
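The actual code does not precompute these probabilities. Instead it builds a table of how often each attribute value occurs in each class, and derives the probabilities it needs from those counts for the sample being classified. Counts are easy to store and read back, and convenient to use. The code is as follows: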
import numpy as np

def bayesian(inX, tranSet, labels):
    '''
    Naive Bayes classifier
    :param inX: feature vector of the sample to classify
    :param tranSet: feature matrix of the training samples, shape (m, n)
    :param labels: class label of each training sample
    :return: dict mapping each class to P(B|Ai)*P(Ai), and an array of
             those values divided by the estimated evidence P(B)
    '''
    m, n = tranSet.shape
    labelsTree = {}        # labelsTree[label][j][value]: count of value in feature j within class label
    labelsCount = {}       # labelsCount[label]: number of training samples in class label
    xCount = np.zeros(n)   # xCount[j]: number of training samples whose feature j equals inX[j]
    for i in range(m):
        label = labels[i]
        if label not in labelsTree:
            labelsTree[label] = {j: {} for j in range(n)}
            labelsCount[label] = 0
        labelsCount[label] += 1
        for j in range(n):
            value = tranSet[i, j]
            labelsTree[label][j][value] = labelsTree[label][j].get(value, 0) + 1
            if inX[j] == value:
                xCount[j] += 1
    # Evidence P(B), estimated as the product of the per-feature frequencies P(Bj)
    xProp = np.prod(xCount / m)
    pVector = {}
    for key in labelsTree:
        p = labelsCount[key] / m   # prior P(Ai)
        for j in range(n):
            # A value never seen in this class falls back to a count of 1,
            # so a single missing value does not zero the whole product
            p *= labelsTree[key][j].get(inX[j], 1) / labelsCount[key]
        pVector[key] = p
    return pVector, np.array(list(pVector.values()), dtype='float') / xProp
The test code is as follows:
import numpy as np
import ml

data = [['<=30', 'high', 'no', 'fair'],
        ['<=30', 'high', 'no', 'excellent'],
        ['31...40', 'high', 'no', 'fair'],
        ['>40', 'medium', 'no', 'fair'],
        ['>40', 'low', 'yes', 'fair'],
        ['>40', 'low', 'yes', 'excellent'],
        ['31...40', 'low', 'yes', 'excellent'],
        ['<=30', 'medium', 'no', 'fair'],
        ['<=30', 'low', 'yes', 'fair'],
        ['>40', 'medium', 'yes', 'fair'],
        ['<=30', 'medium', 'yes', 'excellent'],
        ['31...40', 'medium', 'no', 'excellent'],
        ['31...40', 'high', 'yes', 'fair'],
        ['>40', 'medium', 'no', 'excellent']]
label = ['no', 'no', 'yes', 'yes', 'yes', 'no', 'yes', 'no', 'yes', 'yes', 'yes', 'yes', 'yes', 'no']
inX = ['<=30', 'medium', 'yes', 'fair']
pV = ml.bayesian(np.array(inX), np.array(data), np.array(label))
print(pV)
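As a cross-check, the step 3 quantities P(B|Ai) * P(Ai) for this query can be worked out by hand from the 14 training rows: 9 are labeled yes and 5 are labeled no, and the per-class counts for the query's four values are tallied in the comments. The sketch below uses exact fractions; the variable names are illustrative and not part of the `ml` module.

```python
from fractions import Fraction as F

# Query: ['<=30', 'medium', 'yes', 'fair']
# Class 'yes' (9 of 14 rows): '<=30' in 2 rows, 'medium' in 4, 'yes' in 6, 'fair' in 6
score_yes = F(9, 14) * F(2, 9) * F(4, 9) * F(6, 9) * F(6, 9)
# Class 'no' (5 of 14 rows): '<=30' in 3 rows, 'medium' in 2, 'yes' in 1, 'fair' in 2
score_no = F(5, 14) * F(3, 5) * F(2, 5) * F(1, 5) * F(2, 5)

print(float(score_yes))   # about 0.0282
print(float(score_no))    # about 0.0069
print("yes" if score_yes > score_no else "no")
```

The larger score belongs to yes, so under the four steps above the query sample is assigned to the class yes.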