Weka 學習 ID3

ID3算法相對簡單,weka的實現也容易理解。首先介紹一下大致算法。算法概述如下。

1.選擇一種度量(ID3選擇的是信息增益),計算每個屬性對於該度量的值。

2.根據結果選擇一個屬性進行分支。

3.如果每個分支全部屬於一個類或者已經沒有候選屬性。則停止,否則對每個分支進行1,2操作。

下面對weka的ID3 class 作介紹,主要涉及到makeTree(Instances data),computeInfoGain(data, att),splitData(Instances data, Attribute att)三個函數。其中makeTree是入口函數,computeInfoGain的作用是計算信息增益,splitData的作用是分支。首先看makeTree函數。

private void makeTree(Instances data) throws Exception {


    // Check if no instances have reached this node.
    if (data.numInstances() == 0) {
      m_Attribute = null;
      m_ClassValue = Instance.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }


    // Compute attribute with maximum information gain.
    double[] infoGains = new double[data.numAttributes()];
    Enumeration attEnum = data.enumerateAttributes();
    /**
     * 對每個屬性計算信息增益
     */
    while (attEnum.hasMoreElements()) {
      Attribute att = (Attribute) attEnum.nextElement();
      infoGains[att.index()] = computeInfoGain(data, att);
    }
    m_Attribute = data.attribute(Utils.maxIndex(infoGains));
    
    // Make leaf if information gain is zero. 
    // Otherwise create successors.
    if (Utils.eq(infoGains[m_Attribute.index()], 0)) {
      m_Attribute = null;
      m_Distribution = new double[data.numClasses()];
      Enumeration instEnum = data.enumerateInstances();
      while (instEnum.hasMoreElements()) {
        Instance inst = (Instance) instEnum.nextElement();
        m_Distribution[(int) inst.classValue()]++;
      }
      Utils.normalize(m_Distribution);
      m_ClassValue = Utils.maxIndex(m_Distribution);
      m_ClassAttribute = data.classAttribute();
    } else {
      Instances[] splitData = splitData(data, m_Attribute);
      m_Successors = new Id3[m_Attribute.numValues()];
      /**
       * 這裏對每個分支繼續調用id3.makeTree(instatnces)。
       */
      for (int j = 0; j < m_Attribute.numValues(); j++) {
        m_Successors[j] = new Id3();
        m_Successors[j].makeTree(splitData[j]);
      }
    }
  }


通過註釋,應該不難理解大致過程。這裏需要注意的是 程序裏經常會出現Enumeration,這其實就是現在的Ieratorer,當時jdk版本較低,所以用的Enumeration,忽視掉就好了。

下面是splitData。只是按照類的值進行分支,也很容易理解。

  private double computeInfoGain(Instances data, Attribute att) 
    throws Exception {

    double infoGain = computeEntropy(data);
    Instances[] splitData = splitData(data, att);
    for (int j = 0; j < att.numValues(); j++) {
      if (splitData[j].numInstances() > 0) {
        infoGain -= ((double) splitData[j].numInstances() /
                     (double) data.numInstances()) *
          computeEntropy(splitData[j]);
      }
    }
    return infoGain;
  }


至於 computeInfoGain對照公式就很容易理解了。這裏只貼出代碼

  private double computeInfoGain(Instances data, Attribute att) 
    throws Exception {

    double infoGain = computeEntropy(data);
    Instances[] splitData = splitData(data, att);
    for (int j = 0; j < att.numValues(); j++) {
      if (splitData[j].numInstances() > 0) {
        infoGain -= ((double) splitData[j].numInstances() /
                     (double) data.numInstances()) *
          computeEntropy(splitData[j]);
      }
    }
    return infoGain;
  }


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章