An ADTree has two main kinds of nodes: PredictionNode and SplitNode. The Weka implementation defines a data structure for each of them.
public class PredictionNode
{
double value;
FastVector children;
}
value stores a or b (see the paper for their exact meaning). children stores the SplitNodes.
public abstract class Splitter
{
public int orderAdded;
/**
** There are some other abstract methods as well; they are omitted here for now
**/
}
The author implemented two subclasses of Splitter: TwoWayNominalSplit and TwoWayNumericSplit. Their concrete structure is covered when we get to use them.
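To make the alternating structure concrete, here is a minimal sketch of how a PredictionNode and a two-way nominal split fit together. The class names (SketchPredictionNode, SketchNominalSplit, AdtStructureSketch), the int[] representation of an instance and the score method are all hypothetical, invented for this illustration; they are not Weka's classes. The scoring rule (sum the prediction values on every path the instance can follow) is my understanding of the paper, not something copied from the Weka source.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified stand-in for Weka's PredictionNode.
class SketchPredictionNode {
    final double value;                                            // a or b from the paper
    final List<SketchNominalSplit> children = new ArrayList<>();   // splitters under this node

    SketchPredictionNode(double value) {
        this.value = value;
    }
}

// Hypothetical, simplified stand-in for a two-way nominal splitter.
class SketchNominalSplit {
    final int attIndex;        // attribute tested by this split
    final int trueSplitValue;  // branch 0 is taken when the attribute equals this value
    final SketchPredictionNode[] branches = new SketchPredictionNode[2];

    SketchNominalSplit(int attIndex, int trueSplitValue,
                       double valueIfEqual, double valueIfNotEqual) {
        this.attIndex = attIndex;
        this.trueSplitValue = trueSplitValue;
        branches[0] = new SketchPredictionNode(valueIfEqual);
        branches[1] = new SketchPredictionNode(valueIfNotEqual);
    }

    int branchFor(int[] instance) {
        return instance[attIndex] == trueSplitValue ? 0 : 1;
    }
}

class AdtStructureSketch {
    // Sum the prediction values on every path the instance can follow.
    static double score(SketchPredictionNode node, int[] instance) {
        double sum = node.value;
        for (SketchNominalSplit split : node.children) {
            sum += score(split.branches[split.branchFor(instance)], instance);
        }
        return sum;
    }

    // A root with two splitters attached directly under it.
    static SketchPredictionNode exampleTree() {
        SketchPredictionNode root = new SketchPredictionNode(0.5);
        root.children.add(new SketchNominalSplit(0, 1, 0.3, -0.2));
        root.children.add(new SketchNominalSplit(1, 0, -0.4, 0.1));
        return root;
    }

    public static void main(String[] args) {
        SketchPredictionNode root = exampleTree();
        System.out.println(score(root, new int[] {1, 2}));
        System.out.println(score(root, new int[] {0, 0}));
    }
}
```

Note that a PredictionNode can hold several splitters at once, which is why children is a list rather than a single reference.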
The main class is, naturally, ADTree:
public class ADTree
{
protected Instances m_trainInstances;
/** The root of the tree */
protected PredictionNode m_root = null;
/** The number of the last splitter added to the tree */
protected int m_lastAddedSplitNum = 0;
/** An array containing the indices to the numeric attributes in the data */
protected int[] m_numericAttIndices;
/** An array containing the indices to the nominal attributes in the data */
protected int[] m_nominalAttIndices;
/** The total weight of the instances - used to speed Z calculations */
protected double m_trainTotalWeight;
/** The training instances with positive class - referencing the training dataset */
protected ReferenceInstances m_posTrainInstances;
/** The training instances with negative class - referencing the training dataset */
protected ReferenceInstances m_negTrainInstances;
/** The best node to insert under, as found so far by the latest search */
protected PredictionNode m_search_bestInsertionNode;
/** The best splitter to insert, as found so far by the latest search */
protected Splitter m_search_bestSplitter;
/** The smallest Z value found so far by the latest search */
protected double m_search_smallestZ;
/** The positive instances that apply to the best path found so far */
protected Instances m_search_bestPathPosInstances;
/** The negative instances that apply to the best path found so far */
protected Instances m_search_bestPathNegInstances;
}
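Only the most important member variables are listed here. The author has commented all of them, so readers can get familiar with them first; I will explain them further when analyzing the functions.
Below is the main entry point: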
public void buildClassifier(Instances instances) throws Exception {
/**
** Initializes m_trainInstances, m_posTrainInstances,
m_negTrainInstances,m_root,m_numericAttIndices,
m_nominalAttIndices,m_trainTotalWeight
**/
initClassifier(instances);
// the typical AdaBoost loop
for (int T = 0; T < m_boostingIterations; T++) boost();
}
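I will not paste initClassifier(instances) here; it does exactly what the comment above says, and readers can look it up themselves. Here we analyze the core function, boost():
public void boost() throws Exception {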
if (m_trainInstances == null || m_trainInstances.numInstances() == 0)
throw new Exception("Trying to boost with no training data");
/**
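* Entry point of the search. This is where the SplitNode is actually chosen, i.e. where the
* member variables m_search_bestSplitter and m_search_bestInsertionNode are assigned.
* (The call is abbreviated here; the full signature of searchForBestTestSingle appears later.)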
*/
searchForBestTestSingle();
if (m_search_bestSplitter == null) return; // handle empty instances
/**
* Create the two PredictionNode children for m_search_bestSplitter.
*/
for (int i=0; i<2; i++) {
Instances posInstances =
m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances);
Instances negInstances =
m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances);
double predictionValue = calcPredictionValue(posInstances, negInstances);
PredictionNode newPredictor = new PredictionNode(predictionValue);
updateWeights(posInstances, negInstances, predictionValue);
m_search_bestSplitter.setChildForBranch(i, newPredictor);
}
/**
* Insert the resulting m_search_bestSplitter into the ADTree. This is where m_search_bestInsertionNode comes in:
* it holds the insertion point for m_search_bestSplitter, so the insertion point is not lost across the recursive search.
*/
m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this);
// free memory
m_search_bestPathPosInstances = null;
m_search_bestPathNegInstances = null;
m_search_bestSplitter = null;
}
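In summary, boost() finds the smallest Z (which requires iterating over the candidates), then creates the SplitNode for that smallest Z together with its two PredictionNodes. Next, let us look at searchForBestTestSingle().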
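boost() calls calcPredictionValue and updateWeights, which are not shown in this post. The following is a minimal sketch of what such helpers could look like, assuming the rule from the paper: the prediction value is half the log-ratio of positive to negative weight, and each instance weight is multiplied by exp(-y * value), where y is +1 or -1. The class BoostStepSketch, its method names and the smoothing constant of 1.0 are my own assumptions for illustration, not code taken from Weka.

```java
// Hypothetical sketch of the prediction value and weight update of one boosting step.
class BoostStepSketch {
    // Assumption: a constant added to both weights so the ratio is defined when one side is empty.
    static final double SMOOTHING = 1.0;

    // Half the log-ratio of (smoothed) positive to negative weight.
    static double predictionValue(double posWeight, double negWeight) {
        return 0.5 * Math.log((posWeight + SMOOTHING) / (negWeight + SMOOTHING));
    }

    // label is +1 for positive instances and -1 for negative ones.
    // Instances the new node gets right shrink in weight; the others grow.
    static void reweight(double[] weights, int label, double predictionValue) {
        double factor = Math.exp(-label * predictionValue);
        for (int i = 0; i < weights.length; i++) {
            weights[i] *= factor;
        }
    }

    public static void main(String[] args) {
        double[] pos = {1.0, 1.0, 1.0};  // weights of positive instances reaching the branch
        double[] neg = {1.0};            // weights of negative instances reaching the branch
        double value = predictionValue(3.0, 1.0);
        reweight(pos, +1, value);
        reweight(neg, -1, value);
        System.out.println("prediction value: " + value);
        System.out.println("positive weight:  " + pos[0]);
        System.out.println("negative weight:  " + neg[0]);
    }
}
```

In this toy branch the positives dominate, so the prediction value is positive; the positive instances lose weight and the single negative instance gains weight, which is what steers the next iteration toward the examples that are still misclassified.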
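Note: I have not yet read the handling of continuous numeric data, only the nominal case. I will add it once I have gone through the numeric part.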
private void searchForBestTestSingle(PredictionNode currentNode,
Instances posInstances, Instances negInstances)
throws Exception {
// don't investigate pure or empty nodes any further
if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;
// do z-pure cutoff
/**
* I could not find this formula in the paper...
*/
if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;
/**
* These two statements can be ignored; they only record statistics.
*/
m_nodesExpanded++;
m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();
// evaluate static splitters (nominal)
/**
* Compute the Z-value for each attribute.
*/
for (int i=0; i<m_nominalAttIndices.length; i++)
evaluateNominalSplitSingle(m_nominalAttIndices[i], currentNode,
posInstances, negInstances);
// evaluate dynamic splitters (numeric)
if (m_numericAttIndices.length > 0) {
// merge the two sets of instances into one
Instances allInstances = new Instances(posInstances);
for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); )
allInstances.add((Instance) e.nextElement());
// use method of finding the optimal Z split-point
for (int i=0; i<m_numericAttIndices.length; i++)
evaluateNumericSplitSingle(m_numericAttIndices[i], currentNode,
posInstances, negInstances, allInstances);
}
/**
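* Return point of the recursion. If this PredictionNode has no children, return. If it has,
* continue with each child (a SplitNode) and compute the Z-values for its two branches (two PredictionNodes). If those have children as well, the recursion goes on.
* This shows that Z-values are compared across every insertion point in the whole tree. A higher level may yield a smaller Z-value than a lower level,
* in which case another branch is added at the higher level. The result can therefore be a multi-way tree, and the same attribute may be tested twice.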
*/
if (currentNode.getChildren().size() == 0) return;
// keep searching
switch (m_searchPath) {
case SEARCHPATH_ALL:
goDownAllPathsSingle(currentNode, posInstances, negInstances);
break;
case SEARCHPATH_HEAVIEST:
goDownHeaviestPathSingle(currentNode, posInstances, negInstances);
break;
case SEARCHPATH_ZPURE:
goDownZpurePathSingle(currentNode, posInstances, negInstances);
break;
case SEARCHPATH_RANDOM:
goDownRandomPathSingle(currentNode, posInstances, negInstances);
break;
}
}
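About the calcZpure cutoff that I could not find in the paper: one plausible reading is that it is the Z-value of a hypothetical perfectly pure split at this node, which would make it a lower bound and justify skipping the node when even that bound cannot beat m_search_smallestZ. The sketch below illustrates this reading only. It assumes the paper's Z criterion with 1.0 added to each weight as smoothing; the class ZPureSketch and both method signatures are hypothetical and are not the Weka code.

```java
// Hypothetical sketch: the Z-value of a two-way split and the bound given by a perfectly pure split.
class ZPureSketch {
    // Z of a split sending (posA, negA) down one branch and (posB, negB) down the other;
    // rest is the weight of all training instances that do not reach this node.
    // Assumption: every weight is smoothed by adding 1.0.
    static double z(double posA, double negA, double posB, double negB, double rest) {
        return 2.0 * (Math.sqrt((posA + 1.0) * (negA + 1.0))
                    + Math.sqrt((posB + 1.0) * (negB + 1.0))) + rest;
    }

    // Z of an ideal split: all positives go one way, all negatives the other.
    static double zPure(double posWeight, double negWeight, double totalWeight) {
        return 2.0 * (Math.sqrt(posWeight + 1.0) + Math.sqrt(negWeight + 1.0))
             + (totalWeight - (posWeight + negWeight));
    }

    public static void main(String[] args) {
        // A node reached by positive weight 3 and negative weight 5, out of a total weight of 10.
        System.out.println("pure bound:   " + zPure(3.0, 5.0, 10.0));
        System.out.println("impure split: " + z(2.0, 1.0, 1.0, 4.0, 2.0));
    }
}
```

Under these assumptions zPure(pos, neg, total) equals z(pos, 0, 0, neg, total - pos - neg), so the cutoff would simply say: if even a perfect split here cannot improve on the best Z found so far, stop searching this node.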
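According to the paper, the algorithm computes a Z-value for every combination of PredictionNode and candidate SplitNode. In terms of the tree, a Z-value is computed for each candidate SplitNode at every PredictionNode the search visits. The current minimum Z-value therefore has to be kept in shared state across the recursive calls. That is the purpose of these member variables declared earlier:
protected PredictionNode m_search_bestInsertionNode;
protected Splitter m_search_bestSplitter;
protected double m_search_smallestZ;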
private void evaluateNominalSplitSingle(int attIndex, PredictionNode currentNode,
Instances posInstances, Instances negInstances)
{
double[] indexAndZ = findLowestZNominalSplit(posInstances, negInstances, attIndex);
if (indexAndZ[1] < m_search_smallestZ) {
m_search_smallestZ = indexAndZ[1];
m_search_bestInsertionNode = currentNode;
m_search_bestSplitter = new TwoWayNominalSplit(attIndex, (int) indexAndZ[0]);
m_search_bestPathPosInstances = posInstances;
m_search_bestPathNegInstances = negInstances;
}
}
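For a given attribute, this function computes the Z-value for every attribute value. findLowestZNominalSplit(posInstances, negInstances, attIndex) applies the formula from the paper to each value in turn and picks the smallest. I will not paste its source; it is a very simple function.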
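On return, indexAndZ[0] holds the index of the attribute value and indexAndZ[1] holds the smallest Z-value. The if statement then checks whether this attribute's smallest Z-value is below the current minimum. If it is, the minimum is updated and the search state is saved in the member variables: the three listed above, plus the positive and negative instance sets of the best path.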
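Since the source of findLowestZNominalSplit is not pasted here, the following is a rough sketch of what such a search could look like. It assumes the paper's Z criterion for a two-way test "attribute == v" versus "attribute != v", with 1.0 added to each weight as smoothing. The class NominalSplitSketch, the method findLowestZ and the per-value weight arrays it takes as input are hypothetical simplifications, not Weka's API.

```java
// Hypothetical sketch: pick the attribute value whose "equals / not equals" split has the lowest Z.
class NominalSplitSketch {
    // posWeights[v] and negWeights[v]: total weight of the positive and negative instances
    // at the current node whose attribute takes value v.
    // totalWeight: weight of the whole training set.
    // Returns {index of the best value, lowest Z}, mirroring the indexAndZ array.
    static double[] findLowestZ(double[] posWeights, double[] negWeights, double totalWeight) {
        double posSum = sum(posWeights);
        double negSum = sum(negWeights);
        double rest = totalWeight - (posSum + negSum);  // weight that never reaches this node
        double bestZ = Double.POSITIVE_INFINITY;
        int bestIndex = -1;
        for (int v = 0; v < posWeights.length; v++) {
            double posEq = posWeights[v];
            double negEq = negWeights[v];
            double posNe = posSum - posEq;
            double negNe = negSum - negEq;
            // Assumption: each weight is smoothed by adding 1.0.
            double z = 2.0 * (Math.sqrt((posEq + 1.0) * (negEq + 1.0))
                            + Math.sqrt((posNe + 1.0) * (negNe + 1.0))) + rest;
            if (z < bestZ) {
                bestZ = z;
                bestIndex = v;
            }
        }
        return new double[] {bestIndex, bestZ};
    }

    static double sum(double[] values) {
        double total = 0.0;
        for (double value : values) {
            total += value;
        }
        return total;
    }

    public static void main(String[] args) {
        double[] posWeights = {3.0, 0.0, 1.0};
        double[] negWeights = {0.0, 2.0, 2.0};
        double[] indexAndZ = findLowestZ(posWeights, negWeights, 10.0);
        System.out.println("best value index: " + (int) indexAndZ[0]);
        System.out.println("lowest Z:         " + indexAndZ[1]);
    }
}
```

Returning both results in one double[] matches how the caller reads indexAndZ[0] and indexAndZ[1], at the cost of casting the index back to int.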
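The switch statement dispatches to four functions, which differ in how they choose the PredictionNodes of the next level. Their code is listed below, one at a time.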
private void goDownAllPathsSingle(PredictionNode currentNode,
Instances posInstances, Instances negInstances)
throws Exception {
for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
Splitter split = (Splitter) e.nextElement();
for (int i=0; i<split.getNumOfBranches(); i++)
searchForBestTestSingle(split.getChildForBranch(i),
split.instancesDownBranch(i, posInstances),
split.instancesDownBranch(i, negInstances));
}
}
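This one is the simplest: in terms of the tree, it computes Z-values at every node. (That is all for today; the remaining three functions are also easy to understand.)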