Weka Study: ADTree

An ADTree has two main kinds of nodes: PredictionNode and SplitNode. Weka defines a data structure for each of them.

	public class PredictionNode
	{	
	double value;
	FastVector children;
	}



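value stores a or b, the prediction value of one branch of a rule (see the paper for the exact meaning). children stores the SplitNodes attached below this node.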

	public abstract class Splitter
	{
		public int orderAdded;
		/*
		 * Several other abstract methods are omitted here.
		 */
	}


The author implements two subclasses of Splitter: TwoWayNominalSplit and TwoWayNumericSplit. Their details will be covered when we need them.
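To make the two node types concrete, here is a minimal self-contained sketch. It is not Weka's code: the names AdTreeSketch, Prediction, Split and score are hypothetical. It assumes the standard ADTree classification rule. An instance collects the value of every PredictionNode it reaches, following all splitters under each reached PredictionNode, and the sign of the sum gives the class.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical, simplified model of the two ADTree node types (not Weka's classes).
class AdTreeSketch {

    // A prediction node holds a real-valued score and any number of splitter children.
    static class Prediction {
        final double value;
        final List<Split> children = new ArrayList<>();
        Prediction(double value) { this.value = value; }
    }

    // A two-way split: a test on the instance plus one prediction node per branch.
    static class Split {
        final Predicate<double[]> test; // true -> "yes" branch (value a), false -> "no" branch (value b)
        final Prediction yes;
        final Prediction no;
        Split(Predicate<double[]> test, double a, double b) {
            this.test = test;
            this.yes = new Prediction(a);
            this.no = new Prediction(b);
        }
    }

    // Sum the values of every prediction node the instance reaches.
    // Every splitter under a reached prediction node is followed, not just one.
    static double score(Prediction node, double[] x) {
        double sum = node.value;
        for (Split s : node.children) {
            sum += score(s.test.test(x) ? s.yes : s.no, x);
        }
        return sum;
    }
}
```

Because score() descends into every splitter under a reached PredictionNode, a PredictionNode with several splitter children needs no special handling. A positive total would be read as the positive class.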

The main class is, naturally, ADTree:

public class ADTree
{
	protected Instances m_trainInstances;

  /** The root of the tree */
  protected PredictionNode m_root = null;

  /** The number of the last splitter added to the tree */
  protected int m_lastAddedSplitNum = 0;

  /** An array containing the indices to the numeric attributes in the data */
  protected int[] m_numericAttIndices;

  /** An array containing the indices to the nominal attributes in the data */
  protected int[] m_nominalAttIndices;

  /** The total weight of the instances - used to speed Z calculations */
  protected double m_trainTotalWeight;

  /** The training instances with positive class - referencing the training dataset */
  protected ReferenceInstances m_posTrainInstances;

  /** The training instances with negative class - referencing the training dataset */
  protected ReferenceInstances m_negTrainInstances;

  /** The best node to insert under, as found so far by the latest search */
  protected PredictionNode m_search_bestInsertionNode;

  /** The best splitter to insert, as found so far by the latest search */
  protected Splitter m_search_bestSplitter;

  /** The smallest Z value found so far by the latest search */
  protected double m_search_smallestZ;

  /** The positive instances that apply to the best path found so far */
  protected Instances m_search_bestPathPosInstances;

  /** The negative instances that apply to the best path found so far */
  protected Instances m_search_bestPathNegInstances;
  
}
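Only the most important member variables are listed here, and the author has documented each of them. Skim them now; each one is explained when we reach the methods that use it.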

Below is the main entry point:

 public void buildClassifier(Instances instances) throws Exception {

  
	/*
	 * Initializes m_trainInstances, m_posTrainInstances,
	 * m_negTrainInstances, m_root, m_numericAttIndices,
	 * m_nominalAttIndices and m_trainTotalWeight.
	 */
    initClassifier(instances);

    // a typical AdaBoost loop: one boosting round per iteration
    for (int T = 0; T < m_boostingIterations; T++) boost();


  }
  
I won't paste initClassifier(instances). It does exactly what the comment above describes, and you can look it up yourself. Here we analyze the core function, boost():
  public void boost() throws Exception {

    if (m_trainInstances == null || m_trainInstances.numInstances() == 0)
      throw new Exception("Trying to boost with no training data");

    /*
     * Entry point of each iteration, where the splitter is actually chosen:
     * sets the fields m_search_bestSplitter and m_search_bestInsertionNode.
     */
    searchForBestTestSingle();

    if (m_search_bestSplitter == null) return; // handle empty instances

    /*
     * Create the two child PredictionNodes of m_search_bestSplitter.
     */
    for (int i=0; i<2; i++) {
      Instances posInstances =
	m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances);
      Instances negInstances =
	m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances);
      double predictionValue = calcPredictionValue(posInstances, negInstances);
      PredictionNode newPredictor = new PredictionNode(predictionValue);
      updateWeights(posInstances, negInstances, predictionValue);
      m_search_bestSplitter.setChildForBranch(i, newPredictor);
    }

    /*
     * Insert m_search_bestSplitter into the ADTree. This is where
     * m_search_bestInsertionNode pays off: it records where the splitter goes,
     * so the insertion point is not lost during the recursive search.
     */
    m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this);

    // free memory
    m_search_bestPathPosInstances = null;
    m_search_bestPathNegInstances = null;
    m_search_bestSplitter = null;
  }
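In short, boost() finds the smallest Z value, which requires searching the tree, and then uses it to create a SplitNode together with its two PredictionNodes. Next, let's look at searchForBestTestSingle(). The parameterless call in boost() is an overload not shown here; the recursive version below takes the current node and its positive and negative instances.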
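The per-branch step in boost(), calcPredictionValue followed by updateWeights, can be sketched as follows. This is a sketch under assumptions, not Weka's source. The prediction value uses the paper's smoothed formula a = ½ ln((W+ + 1)/(W− + 1)) over the instances going down one branch; the "+1" smoothing is an assumption. The weight update is the usual AdaBoost rule: positive weights are multiplied by e^(−a) and negative weights by e^(a). Instance weights are plain arrays, and the class name BoostStepSketch is hypothetical.

```java
// Hypothetical sketch of one branch's prediction value and reweighting (not Weka's code).
class BoostStepSketch {

    // Smoothed prediction value for one branch: 0.5 * ln((W+ + 1) / (W- + 1)).
    static double predictionValue(double[] posWeights, double[] negWeights) {
        return 0.5 * Math.log((sum(posWeights) + 1.0) / (sum(negWeights) + 1.0));
    }

    // AdaBoost-style reweighting of the instances that reach the new node:
    // positives (y = +1) are scaled by exp(-v), negatives (y = -1) by exp(v).
    static void updateWeights(double[] posWeights, double[] negWeights, double v) {
        double posFactor = Math.exp(-v);
        double negFactor = Math.exp(v);
        for (int i = 0; i < posWeights.length; i++) posWeights[i] *= posFactor;
        for (int i = 0; i < negWeights.length; i++) negWeights[i] *= negFactor;
    }

    static double sum(double[] w) {
        double s = 0.0;
        for (double x : w) s += x;
        return s;
    }
}
```

In boost() this would run once per branch (i = 0, 1), on the positive and negative instances that go down that branch. A branch dominated by positive weight therefore gets a positive value, and its positives lose weight for the next round.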

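Note: I have only read the handling of nominal (discrete) attributes so far, not continuous numeric ones. I'll add that part once I have read it.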

private void searchForBestTestSingle(PredictionNode currentNode,
				       Instances posInstances, Instances negInstances)
    throws Exception {

    // don't investigate pure or empty nodes any further

    if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;

    // do z-pure cutoff
	  /*
	   * I couldn't find this formula in the paper...
	   */
    if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;

    /*
     * These two lines can be ignored; they only record search statistics.
     */
    m_nodesExpanded++;
    m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();

    // evaluate static splitters (nominal)
    /*
     * Compute the Z value for each nominal attribute.
     */
    for (int i=0; i<m_nominalAttIndices.length; i++)
      evaluateNominalSplitSingle(m_nominalAttIndices[i], currentNode,
				 posInstances, negInstances);

    // evaluate dynamic splitters (numeric)
    if (m_numericAttIndices.length > 0) {

      // merge the two sets of instances into one
      Instances allInstances = new Instances(posInstances);
      for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); )
	allInstances.add((Instance) e.nextElement());
    
      // use method of finding the optimal Z split-point
      for (int i=0; i<m_numericAttIndices.length; i++)
	evaluateNumericSplitSingle(m_numericAttIndices[i], currentNode,
				   posInstances, negInstances, allInstances);
    }
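    /*
     * Return point of the recursion: if this PredictionNode has no children, return.
     * Otherwise, go on to compute the Z values for the two branches (two PredictionNodes)
     * of each child (a SplitNode); if those have children of their own, keep recursing.
     * This shows that Z values are compared across every branch point in the whole tree.
     * An upper level may yield a smaller Z than a lower one, in which case another
     * splitter is added under the upper level. This can produce a multi-way tree, and
     * the same attribute may be tested more than once.
     */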
    if (currentNode.getChildren().size() == 0) return;

    // keep searching
    switch (m_searchPath) {
    case SEARCHPATH_ALL:
      goDownAllPathsSingle(currentNode, posInstances, negInstances);
      break;
    case SEARCHPATH_HEAVIEST: 
      goDownHeaviestPathSingle(currentNode, posInstances, negInstances);
      break;
    case SEARCHPATH_ZPURE: 
      goDownZpurePathSingle(currentNode, posInstances, negInstances);
      break;
    case SEARCHPATH_RANDOM: 
      goDownRandomPathSingle(currentNode, posInstances, negInstances);
      break;
    }
  }
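According to the paper, the algorithm computes a Z value for every combination of PredictionNode and candidate SplitNode. In tree terms, every candidate SplitNode at every position in the tree gets its own Z value. The current minimum Z therefore has to be kept in fields shared by the whole search. That is the purpose of these members:

  protected PredictionNode m_search_bestInsertionNode;
  protected Splitter m_search_bestSplitter;
  protected double m_search_smallestZ;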
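For reference, here is the Z value being minimized, written in the notation of the original ADTree paper. c1 is the precondition (the path leading to a PredictionNode), and c2 is the condition tested by a candidate splitter. W+(c) and W−(c) are the total weights of the positive and negative training instances satisfying c, and W(c) is their sum:

```latex
Z(c_1, c_2) = 2\left(\sqrt{W_+(c_1 \wedge c_2)\,W_-(c_1 \wedge c_2)}
            + \sqrt{W_+(c_1 \wedge \neg c_2)\,W_-(c_1 \wedge \neg c_2)}\right)
            + W(\neg c_1)
```

A smaller Z means the split separates the weighted positives and negatives at that node more cleanly. Instances that do not reach the node contribute W(¬c1) unchanged, which is why the total training weight (m_trainTotalWeight) helps speed up the calculation.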

 private void evaluateNominalSplitSingle(int attIndex, PredictionNode currentNode,
					  Instances posInstances, Instances negInstances)
  {
    
    double[] indexAndZ = findLowestZNominalSplit(posInstances, negInstances, attIndex);

    if (indexAndZ[1] < m_search_smallestZ) {
      m_search_smallestZ = indexAndZ[1];
      m_search_bestInsertionNode = currentNode;
      m_search_bestSplitter = new TwoWayNominalSplit(attIndex, (int) indexAndZ[0]);
      m_search_bestPathPosInstances = posInstances;
      m_search_bestPathNegInstances = negInstances;
    }
  }
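For a given attribute, this function computes a Z value for every value of that attribute. findLowestZNominalSplit(posInstances, negInstances, attIndex) applies the formula from the paper to each value in turn and keeps the smallest. I won't paste its source here; it is a very simple function.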

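Afterwards, indexAndZ[0] holds the index of the best attribute value and indexAndZ[1] holds the corresponding smallest Z value. The if statement that follows checks whether this attribute's minimum Z is smaller than the current global minimum. If it is, the three fields above are updated to record the new best candidate, along with the positive and negative instances on its path.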
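findLowestZNominalSplit is not shown in the original, so the following is only a hypothetical re-implementation of what the text describes. For each value v of one nominal attribute, it considers the two-way split "attribute == v" versus "attribute != v" (assumed here to match the two branches of TwoWayNominalSplit). It computes the Z of that split with the paper's formula, without smoothing, and returns {index of best value, lowest Z}. The input layout (value indices and weights as parallel arrays) and the name NominalZSketch are assumptions.

```java
// Hypothetical sketch of a lowest-Z search over the values of one nominal attribute.
class NominalZSketch {

    // posVals[i] / negVals[i]: nominal value index of each positive / negative instance
    // at this node; posW[i] / negW[i]: their weights; totalWeight: weight of all training
    // instances, so totalWeight minus the node's weight plays the role of W(not c1).
    // Returns {bestValueIndex, lowestZ} for splits of the form "att == v" vs "att != v".
    static double[] lowestZ(int numValues, int[] posVals, double[] posW,
                            int[] negVals, double[] negW, double totalWeight) {
        double[] pv = new double[numValues];
        double[] nv = new double[numValues];
        double p = 0.0, n = 0.0;
        for (int i = 0; i < posVals.length; i++) { pv[posVals[i]] += posW[i]; p += posW[i]; }
        for (int i = 0; i < negVals.length; i++) { nv[negVals[i]] += negW[i]; n += negW[i]; }

        double rest = totalWeight - (p + n); // weight of instances not reaching this node
        double[] best = {-1, Double.POSITIVE_INFINITY};
        for (int v = 0; v < numValues; v++) {
            // Math.max guards against tiny negative differences from rounding.
            double z = 2.0 * (Math.sqrt(pv[v] * nv[v])
                    + Math.sqrt(Math.max(0.0, p - pv[v]) * Math.max(0.0, n - nv[v])))
                    + rest;
            if (z < best[1]) { best[0] = v; best[1] = z; }
        }
        return best;
    }
}
```

The returned array has the same layout as indexAndZ above: element 0 is the value index (stored as a double), and element 1 is the Z value.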

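The switch statement dispatches to one of four functions. They differ in how they choose which PredictionNodes on the next level to descend into. The four functions are covered one at a time below.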

 private void goDownAllPathsSingle(PredictionNode currentNode,
				    Instances posInstances, Instances negInstances)
    throws Exception {

    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
      Splitter split = (Splitter) e.nextElement();
      for (int i=0; i<split.getNumOfBranches(); i++)
	searchForBestTestSingle(split.getChildForBranch(i),
				split.instancesDownBranch(i, posInstances),
				split.instancesDownBranch(i, negInstances));
    }
  }
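This one is the simplest. It recurses into every branch of every splitter, so in tree terms a Z value is computed at every node. (That's all for today; the other three functions are also easy to understand.)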
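The recursion in goDownAllPathsSingle, combined with the shared m_search_* fields, can be illustrated with a toy version. In this hypothetical sketch, AllPathsSearchSketch, its classes and the zOf scoring function are all made up. Instances are plain integers, each splitter routes them with a predicate, and zOf stands in for "the lowest Z of any candidate split at this node". Like SEARCHPATH_ALL, the search visits every branch of every splitter and keeps the global minimum in fields.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import java.util.function.ToDoubleFunction;

// Hypothetical toy model of an all-paths search with a global best (not Weka's code).
class AllPathsSearchSketch {

    static class Prediction {
        final String name;
        final List<Split> children = new ArrayList<>();
        Prediction(String name) { this.name = name; }
    }

    static class Split {
        final Predicate<Integer> test;
        final Prediction[] branches; // branch 0: test true, branch 1: test false
        Split(Predicate<Integer> test, Prediction yes, Prediction no) {
            this.test = test;
            this.branches = new Prediction[] {yes, no};
        }
        // Analogue of instancesDownBranch: keep the instances routed to this branch.
        List<Integer> instancesDownBranch(int branch, List<Integer> xs) {
            List<Integer> out = new ArrayList<>();
            for (int x : xs) {
                if (test.test(x) == (branch == 0)) out.add(x);
            }
            return out;
        }
    }

    // Search state shared by the whole recursion, like the m_search_* fields.
    Prediction bestNode;
    double smallestZ = Double.POSITIVE_INFINITY;

    void search(Prediction node, List<Integer> xs, ToDoubleFunction<List<Integer>> zOf) {
        if (xs.isEmpty()) return;                 // nothing to evaluate down this path
        double z = zOf.applyAsDouble(xs);         // stand-in for the best Z at this node
        if (z < smallestZ) { smallestZ = z; bestNode = node; }
        for (Split s : node.children) {           // every splitter under this node...
            for (int i = 0; i < s.branches.length; i++) {   // ...and every branch
                search(s.branches[i], s.instancesDownBranch(i, xs), zOf);
            }
        }
    }
}
```

Because smallestZ and bestNode live in fields rather than in return values, a candidate deep in the tree and one near the root compete directly. This is the behavior that lets an upper PredictionNode receive an additional splitter.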

