The ID3 algorithm is relatively simple, and weka's implementation is easy to follow. The algorithm, in outline:
1. Choose a metric (ID3 uses information gain) and compute its value for every candidate attribute.
2. Split on the attribute that scores best on that metric.
3. If every branch is pure (all instances belong to one class) or no candidate attributes remain, stop; otherwise repeat steps 1 and 2 on each branch.
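Before diving into the weka source, the metric in step 1 can be made concrete with a self-contained sketch (plain Java, independent of weka; the counts come from the classic 14-instance weather dataset and are used here only as example numbers, they are not taken from the weka code below):

```java
import java.util.Arrays;

public class InfoGainDemo {

    // Entropy of a class distribution given raw counts, in bits.
    static double entropy(int[] counts) {
        int total = Arrays.stream(counts).sum();
        double e = 0.0;
        for (int c : counts) {
            if (c > 0) {
                double p = (double) c / total;
                e -= p * (Math.log(p) / Math.log(2));
            }
        }
        return e;
    }

    // Information gain = entropy(parent) - weighted entropy of the branches.
    static double infoGain(int[] parent, int[][] branches) {
        int total = Arrays.stream(parent).sum();
        double gain = entropy(parent);
        for (int[] branch : branches) {
            int n = Arrays.stream(branch).sum();
            gain -= ((double) n / total) * entropy(branch);
        }
        return gain;
    }

    public static void main(String[] args) {
        // Weather data: 9 "yes" / 5 "no"; splitting on Outlook gives
        // sunny = {2 yes, 3 no}, overcast = {4 yes, 0 no}, rainy = {3 yes, 2 no}.
        int[] parent = {9, 5};
        int[][] outlook = {{2, 3}, {4, 0}, {3, 2}};
        System.out.printf("entropy = %.3f%n", entropy(parent));            // 0.940
        System.out.printf("gain    = %.3f%n", infoGain(parent, outlook));  // 0.247
    }
}
```

ID3 evaluates this gain for every attribute and splits on the maximum.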
Below is a walkthrough of weka's Id3 class, focusing on three methods: makeTree(Instances data), computeInfoGain(Instances data, Attribute att), and splitData(Instances data, Attribute att). makeTree is the entry point (invoked from buildClassifier), computeInfoGain computes the information gain, and splitData partitions the data for a split. First, makeTree:
private void makeTree(Instances data) throws Exception {

  // Check if no instances have reached this node.
  if (data.numInstances() == 0) {
    m_Attribute = null;
    m_ClassValue = Instance.missingValue();
    m_Distribution = new double[data.numClasses()];
    return;
  }

  // Compute attribute with maximum information gain.
  double[] infoGains = new double[data.numAttributes()];
  Enumeration attEnum = data.enumerateAttributes();
  // Compute the information gain for every candidate attribute.
  while (attEnum.hasMoreElements()) {
    Attribute att = (Attribute) attEnum.nextElement();
    infoGains[att.index()] = computeInfoGain(data, att);
  }
  m_Attribute = data.attribute(Utils.maxIndex(infoGains));

  // Make leaf if information gain is zero.
  // Otherwise create successors.
  if (Utils.eq(infoGains[m_Attribute.index()], 0)) {
    m_Attribute = null;
    m_Distribution = new double[data.numClasses()];
    Enumeration instEnum = data.enumerateInstances();
    while (instEnum.hasMoreElements()) {
      Instance inst = (Instance) instEnum.nextElement();
      m_Distribution[(int) inst.classValue()]++;
    }
    Utils.normalize(m_Distribution);
    m_ClassValue = Utils.maxIndex(m_Distribution);
    m_ClassAttribute = data.classAttribute();
  } else {
    Instances[] splitData = splitData(data, m_Attribute);
    m_Successors = new Id3[m_Attribute.numValues()];
    // Recurse: call makeTree on each branch's subset.
    for (int j = 0; j < m_Attribute.numValues(); j++) {
      m_Successors[j] = new Id3();
      m_Successors[j].makeTree(splitData[j]);
    }
  }
}
With the comments, the overall flow should not be hard to follow. One thing worth noting: Enumeration appears throughout the code. It is essentially today's Iterator; weka was written against an older JDK, which is why Enumeration is used. You can simply read it as an Iterator.
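To illustrate that point, here is a small standalone sketch (not weka code) walking the same Vector with both interfaces:

```java
import java.util.Enumeration;
import java.util.Iterator;
import java.util.Vector;

public class EnumVsIterator {

    // Old style, as used throughout weka's Id3: Enumeration.
    static String joinViaEnumeration(Vector<String> v) {
        StringBuilder sb = new StringBuilder();
        Enumeration<String> e = v.elements();
        while (e.hasMoreElements()) {
            sb.append(e.nextElement());
        }
        return sb.toString();
    }

    // The modern equivalent: Iterator.
    static String joinViaIterator(Vector<String> v) {
        StringBuilder sb = new StringBuilder();
        Iterator<String> it = v.iterator();
        while (it.hasNext()) {
            sb.append(it.next());
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        Vector<String> v = new Vector<>();
        v.add("a"); v.add("b"); v.add("c");
        System.out.println(joinViaEnumeration(v));  // abc
        System.out.println(joinViaIterator(v));     // abc
    }
}
```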
Next is splitData. It simply partitions the instances according to the values of the given attribute, and is also easy to follow.
private Instances[] splitData(Instances data, Attribute att) {

  // One empty subset per value of the split attribute.
  Instances[] splitData = new Instances[att.numValues()];
  for (int j = 0; j < att.numValues(); j++) {
    splitData[j] = new Instances(data, data.numInstances());
  }

  // Route each instance to the subset matching its value of att.
  Enumeration instEnum = data.enumerateInstances();
  while (instEnum.hasMoreElements()) {
    Instance inst = (Instance) instEnum.nextElement();
    splitData[(int) inst.value(att)].add(inst);
  }
  for (int i = 0; i < splitData.length; i++) {
    splitData[i].compactify();
  }
  return splitData;
}
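The same partitioning idea can be shown with plain Java collections (a sketch for illustration only, not weka's API; nominal attribute values are encoded as small integers per column):

```java
import java.util.ArrayList;
import java.util.List;

public class SplitDemo {

    // Partition rows into one bucket per value of the attribute in column
    // attIndex, mirroring what splitData does with weka's Instances.
    static List<List<int[]>> split(List<int[]> rows, int attIndex, int numValues) {
        List<List<int[]>> buckets = new ArrayList<>();
        for (int j = 0; j < numValues; j++) {
            buckets.add(new ArrayList<>());
        }
        for (int[] row : rows) {
            buckets.get(row[attIndex]).add(row);
        }
        return buckets;
    }

    public static void main(String[] args) {
        // Three rows; attribute 0 takes the values 0, 1, 0.
        List<int[]> rows = List.of(new int[]{0, 1}, new int[]{1, 0}, new int[]{0, 0});
        List<List<int[]>> buckets = split(rows, 0, 2);
        System.out.println(buckets.get(0).size());  // 2
        System.out.println(buckets.get(1).size());  // 1
    }
}
```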
As for computeInfoGain, it is easy to follow once set against the information-gain formula, so only the code is shown here:
private double computeInfoGain(Instances data, Attribute att)
    throws Exception {

  double infoGain = computeEntropy(data);
  Instances[] splitData = splitData(data, att);
  for (int j = 0; j < att.numValues(); j++) {
    if (splitData[j].numInstances() > 0) {
      infoGain -= ((double) splitData[j].numInstances() /
                   (double) data.numInstances()) *
                  computeEntropy(splitData[j]);
    }
  }
  return infoGain;
}
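computeInfoGain relies on computeEntropy, which is not shown above: it tallies the class counts of the given instances and computes the entropy of that distribution. A useful algebraic rearrangement (which weka's implementation exploits) is that −Σ (c/N)·log2(c/N) equals log2 N − (1/N)·Σ c·log2 c. A self-contained sketch of that computation, working from raw counts rather than Instances (so an illustration of the idea, not weka's exact method):

```java
public class EntropyDemo {

    static double log2(double x) {
        return Math.log(x) / Math.log(2);
    }

    // Entropy of a class distribution, computed in the rearranged form
    // log2(N) - (1/N) * sum(c * log2(c)).
    static double entropy(double[] classCounts) {
        double total = 0;
        for (double c : classCounts) {
            total += c;
        }
        double entropy = 0;
        for (double c : classCounts) {
            if (c > 0) {
                entropy -= c * log2(c);  // zero counts contribute nothing
            }
        }
        entropy /= total;
        return entropy + log2(total);
    }

    public static void main(String[] args) {
        System.out.printf("%.3f%n", entropy(new double[]{9, 5}));  // 0.940
        System.out.printf("%.3f%n", entropy(new double[]{4, 0}));  // 0.000 (pure node)
    }
}
```

A pure node has entropy 0, which is what makes the "information gain is zero" test in makeTree a natural stopping criterion.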