決策樹C4.5算法Java代碼

 

 

-

加入菜鳥學習網,獲得珍藏資源

Java代碼

  1. 數據挖掘中決策樹C4.5預測算法實現(半成品,還要寫規則後剪枝及對非離散數據信息增益計算)

Java代碼

  1. package org.struct.decisiontree;  
  2. import java.util.ArrayList;  
  3. import java.util.Arrays;  
  4. import java.util.List;  
  5. import java.util.TreeSet;  
  6. /**
  7. * @author Leon.Chen
  8. */
  9. public class DecisionTreeBaseC4p5 {  
  10. /**
  11.      * root node
  12.      */
  13. private DecisionTreeNode root;  
  14. /**
  15.      * visableArray
  16.      */
  17. private boolean[] visable;  
  18. private static final int NOT_FOUND = -1;  
  19. private static final int DATA_START_LINE = 1;  
  20. private Object[] trainingArray;  
  21. private String[] columnHeaderArray;  
  22. /**
  23.      * forecast node index
  24.      */
  25. private int nodeIndex;  
  26. /**
  27.      * @param args
  28.      */
  29. @SuppressWarnings("boxing")  
  30. public static void main(String[] args) {  
  31.         Object[] array = new Object[] {  
  32. new String[] { "age",          "income",   "student", "credit_rating", "buys_computer" },  
  33. new String[] { "youth",        "high",     "no",      "fair",          "no"  },  
  34. new String[] { "youth",        "high",     "no",      "excellent",     "no"  },  
  35. new String[] { "middle_aged",  "high",     "no",      "fair",          "yes" },  
  36. new String[] { "senior",       "medium",   "no",      "fair",          "yes" },  
  37. new String[] { "senior",       "low",      "yes",     "fair",          "yes" },  
  38. new String[] { "senior",       "low",      "yes",     "excellent",     "no"  },  
  39. new String[] { "middle_aged",  "low",      "yes",     "excellent",     "yes" },  
  40. new String[] { "youth",        "medium",   "no",      "fair",          "no"  },  
  41. new String[] { "youth",        "low",      "yes",     "fair",          "yes" },  
  42. new String[] { "senior",       "medium",   "yes",     "fair",          "yes" },  
  43. new String[] { "youth",        "medium",   "yes",     "excellent",     "yes" },  
  44. new String[] { "middle_aged",  "medium",   "no",      "excellent",     "yes" },  
  45. new String[] { "middle_aged",  "high",     "yes",     "fair",          "yes" },  
  46. new String[] { "senior",       "medium",   "no",      "excellent",     "no"  },  
  47.         };  
  48.         DecisionTreeBaseC4p5 tree = new DecisionTreeBaseC4p5();  
  49.         tree.create(array, 4);  
  50.         System.out.println("===============END PRINT TREE===============");  
  51.         System.out.println("===============DECISION RESULT===============");  
  52. //tree.forecast(printData, tree.root);
  53.     }  
  54. /**
  55.      * @param printData
  56.      * @param node
  57.      */
  58. public void forecast(String[] printData, DecisionTreeNode node) {  
  59. int index = getColumnHeaderIndexByName(node.nodeName);  
  60. if (index == NOT_FOUND) {  
  61.             System.out.println(node.nodeName);  
  62.         }  
  63.         DecisionTreeNode[] childs = node.childNodesArray;  
  64. for (int i = 0; i < childs.length; i++) {  
  65. if (childs[i] != null) {  
  66. if (childs[i].parentArrtibute.equals(printData[index])) {  
  67.                     forecast(printData, childs[i]);  
  68.                 }  
  69.             }  
  70.         }  
  71.     }  
  72. /**
  73.      * @param array
  74.      * @param index
  75.      */
  76. public void create(Object[] array, int index) {  
  77. this.trainingArray = Arrays.copyOfRange(array, DATA_START_LINE,  
  78.                 array.length);  
  79.         init(array, index);  
  80.         createDecisionTree(this.trainingArray);  
  81.         printDecisionTree(root);  
  82.     }  
  83. /**
  84.      * @param array
  85.      * @return Object[]
  86.      */
  87. @SuppressWarnings("boxing")  
  88. public Object[] getMaxGain(Object[] array) {  
  89.         Object[] result = new Object[2];  
  90. double gain = 0;  
  91. int index = -1;  
  92. for (int i = 0; i < visable.length; i++) {  
  93. if (!visable[i]) {  
  94. //TODO ID3 change to C4.5
  95. double value = gainRatio(array, i, this.nodeIndex);  
  96.                 System.out.println(value);  
  97. if (gain < value) {  
  98.                     gain = value;  
  99.                     index = i;  
  100.                 }  
  101.             }  
  102.         }  
  103.         result[0] = gain;  
  104.         result[1] = index;  
  105. // TODO throws can't forecast this model exception
  106. if (index != -1) {  
  107.             visable[index] = true;  
  108.         }  
  109. return result;  
  110.     }  
  111. /**
  112.      * @param array
  113.      */
  114. public void createDecisionTree(Object[] array) {  
  115.         Object[] maxgain = getMaxGain(array);  
  116. if (root == null) {  
  117.             root = new DecisionTreeNode();  
  118.             root.parentNode = null;  
  119.             root.parentArrtibute = null;  
  120.             root.arrtibutesArray = getArrtibutesArray(((Integer) maxgain[1])  
  121.                     .intValue());  
  122.             root.nodeName = getColumnHeaderNameByIndex(((Integer) maxgain[1])  
  123.                     .intValue());  
  124.             root.childNodesArray = new DecisionTreeNode[root.arrtibutesArray.length];  
  125.             insertDecisionTree(array, root);  
  126.         }  
  127.     }  
  128. /**
  129.      * @param array
  130.      * @param parentNode
  131.      */
  132. public void insertDecisionTree(Object[] array, DecisionTreeNode parentNode) {  
  133.         String[] arrtibutes = parentNode.arrtibutesArray;  
  134. for (int i = 0; i < arrtibutes.length; i++) {  
  135.             Object[] pickArray = pickUpAndCreateSubArray(array, arrtibutes[i],  
  136.                     getColumnHeaderIndexByName(parentNode.nodeName));  
  137.             Object[] info = getMaxGain(pickArray);  
  138. double gain = ((Double) info[0]).doubleValue();  
  139. if (gain != 0) {  
  140. int index = ((Integer) info[1]).intValue();  
  141.                 DecisionTreeNode currentNode = new DecisionTreeNode();  
  142.                 currentNode.parentNode = parentNode;  
  143.                 currentNode.parentArrtibute = arrtibutes[i];  
  144.                 currentNode.arrtibutesArray = getArrtibutesArray(index);  
  145.                 currentNode.nodeName = getColumnHeaderNameByIndex(index);  
  146.                 currentNode.childNodesArray = new DecisionTreeNode[currentNode.arrtibutesArray.length];  
  147.                 parentNode.childNodesArray[i] = currentNode;  
  148.                 insertDecisionTree(pickArray, currentNode);  
  149.             } else {  
  150.                 DecisionTreeNode leafNode = new DecisionTreeNode();  
  151.                 leafNode.parentNode = parentNode;  
  152.                 leafNode.parentArrtibute = arrtibutes[i];  
  153.                 leafNode.arrtibutesArray = new String[0];  
  154.                 leafNode.nodeName = getLeafNodeName(pickArray,this.nodeIndex);  
  155.                 leafNode.childNodesArray = new DecisionTreeNode[0];  
  156.                 parentNode.childNodesArray[i] = leafNode;  
  157.             }  
  158.         }  
  159.     }  
  160. /**
  161.      * @param node
  162.      */
  163. public void printDecisionTree(DecisionTreeNode node) {  
  164.         System.out.println(node.nodeName);  
  165.         DecisionTreeNode[] childs = node.childNodesArray;  
  166. for (int i = 0; i < childs.length; i++) {  
  167. if (childs[i] != null) {  
  168.                 System.out.println(childs[i].parentArrtibute);  
  169.                 printDecisionTree(childs[i]);  
  170.             }  
  171.         }  
  172.     }  
  173. /**
  174.      * init data
  175.      * 
  176.      * @param dataArray
  177.      * @param index
  178.      */
  179. public void init(Object[] dataArray, int index) {  
  180. this.nodeIndex = index;  
  181. // init data
  182. this.columnHeaderArray = (String[]) dataArray[0];  
  183.         visable = new boolean[((String[]) dataArray[0]).length];  
  184. for (int i = 0; i < visable.length; i++) {  
  185. if (i == index) {  
  186.                 visable[i] = true;  
  187.             } else {  
  188.                 visable[i] = false;  
  189.             }  
  190.         }  
  191.     }  
  192. /**
  193.      * @param array
  194.      * @param arrtibute
  195.      * @param index
  196.      * @return Object[]
  197.      */
  198. public Object[] pickUpAndCreateSubArray(Object[] array, String arrtibute,  
  199. int index) {  
  200.         List list = new ArrayList();  
  201. for (int i = 0; i < array.length; i++) {  
  202.             String[] strs = (String[]) array[i];  
  203. if (strs[index].equals(arrtibute)) {  
  204.                 list.add(strs);  
  205.             }  
  206.         }  
  207. return list.toArray();  
  208.     }  
  209. /**
  210.      * gain(A)
  211.      * 
  212.      * @param array
  213.      * @param index
  214.      * @return double
  215.      */
  216. public double gain(Object[] array, int index, int nodeIndex) {  
  217. int[] counts = separateToSameValueArrays(array, nodeIndex);  
  218.         String[] arrtibutes = getArrtibutesArray(index);  
  219. double infoD = infoD(array, counts);  
  220. double infoaD = infoaD(array, index, nodeIndex, arrtibutes);  
  221. return infoD - infoaD;  
  222.     }  
  223. /**
  224.      * @param array
  225.      * @param nodeIndex
  226.      * @return
  227.      */
  228. public int[] separateToSameValueArrays(Object[] array, int nodeIndex) {  
  229.         String[] arrti = getArrtibutesArray(nodeIndex);  
  230. int[] counts = new int[arrti.length];  
  231. for (int i = 0; i < counts.length; i++) {  
  232.             counts[i] = 0;  
  233.         }  
  234. for (int i = 0; i < array.length; i++) {  
  235.             String[] strs = (String[]) array[i];  
  236. for (int j = 0; j < arrti.length; j++) {  
  237. if (strs[nodeIndex].equals(arrti[j])) {  
  238.                     counts[j]++;  
  239.                 }  
  240.             }  
  241.         }  
  242. return counts;  
  243.     }  
  244. /**
  245.      * gainRatio = gain(A)/splitInfo(A)
  246.      * 
  247.      * @param array
  248.      * @param index
  249.      * @param nodeIndex
  250.      * @return
  251.      */
  252. public double gainRatio(Object[] array,int index,int nodeIndex){  
  253. double gain = gain(array,index,nodeIndex);  
  254. int[] counts = separateToSameValueArrays(array, index);  
  255. double splitInfo = splitInfoaD(array,counts);  
  256. if(splitInfo != 0){  
  257. return gain/splitInfo;  
  258.         }  
  259. return 0;  
  260.     }  
  261. /**
  262.      * infoD = -E(pi*log2 pi)
  263.      * 
  264.      * @param array
  265.      * @param counts
  266.      * @return
  267.      */
  268. public double infoD(Object[] array, int[] counts) {  
  269. double infoD = 0;  
  270. for (int i = 0; i < counts.length; i++) {  
  271.             infoD += DecisionTreeUtil.info(counts[i], array.length);  
  272.         }  
  273. return infoD;  
  274.     }  
  275.  
  276. /**
  277.      * splitInfoaD = -E|Dj|/|D|*log2(|Dj|/|D|)
  278.      * 
  279.      * @param array
  280.      * @param counts
  281.      * @return
  282.      */
  283. public double splitInfoaD(Object[] array, int[] counts) {  
  284. return infoD(array, counts);  
  285.     }  
  286. /**
  287.      * infoaD = E(|Dj| / |D|) * info(Dj)
  288.      * 
  289.      * @param array
  290.      * @param index
  291.      * @param arrtibutes
  292.      * @return
  293.      */
  294. public double infoaD(Object[] array, int index, int nodeIndex,  
  295.             String[] arrtibutes) {  
  296. double sv_total = 0;  
  297. for (int i = 0; i < arrtibutes.length; i++) {  
  298.             sv_total += infoDj(array, index, nodeIndex, arrtibutes[i],  
  299.                     array.length);  
  300.         }  
  301. return sv_total;  
  302.     }  
  303. /**
  304.      * ((|Dj| / |D|) * Info(Dj))
  305.      * 
  306.      * @param array
  307.      * @param index
  308.      * @param arrtibute
  309.      * @param allTotal
  310.      * @return double
  311.      */
  312. public double infoDj(Object[] array, int index, int nodeIndex,  
  313.             String arrtibute, int allTotal) {  
  314.         String[] arrtibutes = getArrtibutesArray(nodeIndex);  
  315. int[] counts = new int[arrtibutes.length];  
  316. for (int i = 0; i < counts.length; i++) {  
  317.             counts[i] = 0;  
  318.         }  
  319. for (int i = 0; i < array.length; i++) {  
  320.             String[] strs = (String[]) array[i];  
  321. if (strs[index].equals(arrtibute)) {  
  322. for (int k = 0; k < arrtibutes.length; k++) {  
  323. if (strs[nodeIndex].equals(arrtibutes[k])) {  
  324.                         counts[k]++;  
  325.                     }  
  326.                 }  
  327.             }  
  328.         }  
  329. int total = 0;  
  330. double infoDj = 0;  
  331. for (int i = 0; i < counts.length; i++) {  
  332.             total += counts[i];  
  333.         }  
  334. for (int i = 0; i < counts.length; i++) {  
  335.             infoDj += DecisionTreeUtil.info(counts[i], total);  
  336.         }  
  337. return DecisionTreeUtil.getPi(total, allTotal) * infoDj;  
  338.     }  
  339. /**
  340.      * @param index
  341.      * @return String[]
  342.      */
  343. @SuppressWarnings("unchecked")  
  344. public String[] getArrtibutesArray(int index) {  
  345.         TreeSet set = new TreeSet(new SequenceComparator());  
  346. for (int i = 0; i < trainingArray.length; i++) {  
  347.             String[] strs = (String[]) trainingArray[i];  
  348.             set.add(strs[index]);  
  349.         }  
  350.         String[] result = new String[set.size()];  
  351. return set.toArray(result);  
  352.     }  
  353. /**
  354.      * @param index
  355.      * @return String
  356.      */
  357. public String getColumnHeaderNameByIndex(int index) {  
  358. for (int i = 0; i < columnHeaderArray.length; i++) {  
  359. if (i == index) {  
  360. return columnHeaderArray[i];  
  361.             }  
  362.         }  
  363. return null;  
  364.     }  
  365. /**
  366.      * @param array
  367.      * @return String
  368.      */
  369. public String getLeafNodeName(Object[] array,int nodeIndex) {  
  370. if (array != null && array.length > 0) {  
  371.             String[] strs = (String[]) array[0];  
  372. return strs[nodeIndex];  
  373.         }  
  374. return null;  
  375.     }  
  376. /**
  377.      * @param name
  378.      * @return int
  379.      */
  380. public int getColumnHeaderIndexByName(String name) {  
  381. for (int i = 0; i < columnHeaderArray.length; i++) {  
  382. if (name.equals(columnHeaderArray[i])) {  
  383. return i;  
  384.             }  
  385.         }  
  386. return NOT_FOUND;  
  387.     }  
package org.struct.decisiontree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

/**
 * Decision tree builder for purely categorical (String) data that selects
 * split attributes by gain ratio (C4.5 style). Rule post-pruning and
 * continuous attributes are not implemented.
 *
 * <p>The training data is an {@code Object[]} whose element 0 is the
 * {@code String[]} of column headers and whose remaining elements are
 * {@code String[]} rows with one value per column.
 *
 * <p>Instances hold mutable state and are not designed for concurrent use.
 *
 * @author Leon.Chen
 */
public class DecisionTreeBaseC4p5 {

	/**
	 * Root of the tree built by {@link #create(Object[], int)}; {@code null}
	 * until a tree has been built.
	 */
	private DecisionTreeNode root;

	/**
	 * One flag per column. {@code true} means the column must not be offered
	 * as a split candidate: it is either the target column (set by
	 * {@link #init(Object[], int)}) or it has been chosen by
	 * {@link #getMaxGain(Object[])} on the path currently being expanded.
	 */
	private boolean[] visable;

	/** Sentinel returned when a column or attribute cannot be resolved. */
	private static final int NOT_FOUND = -1;

	/** Index of the first data row in the input array (row 0 is the header). */
	private static final int DATA_START_LINE = 1;

	/** All data rows (header row excluded); each element is a {@code String[]}. */
	private Object[] trainingArray;

	/** Column names taken from row 0 of the input array. */
	private String[] columnHeaderArray;

	/**
	 * Index of the target (class label) column that the tree predicts.
	 */
	private int nodeIndex;

	/**
	 * Builds and prints a tree for a small sample data set.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		Object[] array = new Object[] {
				new String[] { "age",          "income",   "student", "credit_rating", "buys_computer" },
				new String[] { "youth",        "high",     "no",      "fair",          "no"  },
				new String[] { "youth",        "high",     "no",      "excellent",     "no"  },
				new String[] { "middle_aged",  "high",     "no",      "fair",          "yes" },
				new String[] { "senior",       "medium",   "no",      "fair",          "yes" },
				new String[] { "senior",       "low",      "yes",     "fair",          "yes" },
				new String[] { "senior",       "low",      "yes",     "excellent",     "no"  },
				new String[] { "middle_aged",  "low",      "yes",     "excellent",     "yes" },
				new String[] { "youth",        "medium",   "no",      "fair",          "no"  },
				new String[] { "youth",        "low",      "yes",     "fair",          "yes" },
				new String[] { "senior",       "medium",   "yes",     "fair",          "yes" },
				new String[] { "youth",        "medium",   "yes",     "excellent",     "yes" },
				new String[] { "middle_aged",  "medium",   "no",      "excellent",     "yes" },
				new String[] { "middle_aged",  "high",     "yes",     "fair",          "yes" },
				new String[] { "senior",       "medium",   "no",      "excellent",     "no"  },
		};

		DecisionTreeBaseC4p5 tree = new DecisionTreeBaseC4p5();
		tree.create(array, 4);
		System.out.println("===============END PRINT TREE===============");
		System.out.println("===============DECISION RESULT===============");
		//tree.forecast(printData, tree.root);
	}

	/**
	 * Walks the tree from {@code node} following the values in
	 * {@code printData} and prints the label of the leaf that is reached.
	 * Nothing is printed when no child matches the record's value.
	 *
	 * @param printData one record, indexed like the column header row
	 * @param node      node to start from (normally the root)
	 */
	public void forecast(String[] printData, DecisionTreeNode node) {
		int index = getColumnHeaderIndexByName(node.nodeName);
		if (index == NOT_FOUND) {
			// The name is not a column header, so this node is a leaf holding
			// the predicted label. Stop here: indexing printData with
			// NOT_FOUND below would be out of bounds.
			System.out.println(node.nodeName);
			return;
		}
		DecisionTreeNode[] childs = node.childNodesArray;
		for (int i = 0; i < childs.length; i++) {
			if (childs[i] != null
					&& childs[i].parentArrtibute.equals(printData[index])) {
				forecast(printData, childs[i]);
			}
		}
	}

	/**
	 * Builds the decision tree from {@code array} and prints it. Any tree
	 * built by an earlier call is discarded.
	 *
	 * @param array header row followed by the data rows
	 * @param index index of the target (class label) column
	 * @throws IllegalArgumentException if {@code array} has no header row
	 */
	public void create(Object[] array, int index) {
		if (array.length == 0) {
			throw new IllegalArgumentException(
					"array must contain at least the column header row");
		}
		this.trainingArray = Arrays.copyOfRange(array, DATA_START_LINE,
				array.length);
		// Forget a previously built tree; otherwise createDecisionTree would
		// keep the stale root and the new data would be ignored.
		this.root = null;
		init(array, index);
		createDecisionTree(this.trainingArray);
		printDecisionTree(root);
	}

	/**
	 * Finds the unused column with the highest positive gain ratio on
	 * {@code array}, printing every candidate's gain ratio along the way.
	 *
	 * <p>Side effect: the chosen column is flagged in {@link #visable}. The
	 * caller is responsible for clearing the flag again once the subtree
	 * rooted at that column has been built.
	 *
	 * @param array rows to evaluate
	 * @return two elements: {@code [0]} the best gain ratio as a
	 *         {@code Double} (0 when no column has positive gain) and
	 *         {@code [1]} the chosen column index as an {@code Integer}
	 *         (-1 when none was chosen)
	 */
	public Object[] getMaxGain(Object[] array) {
		Object[] result = new Object[2];
		double gain = 0;
		int index = NOT_FOUND;

		for (int i = 0; i < visable.length; i++) {
			if (!visable[i]) {
				double value = gainRatio(array, i, this.nodeIndex);
				System.out.println(value);
				if (gain < value) {
					gain = value;
					index = i;
				}
			}
		}
		result[0] = Double.valueOf(gain);
		result[1] = Integer.valueOf(index);
		if (index != NOT_FOUND) {
			visable[index] = true;
		}
		return result;
	}

	/**
	 * Creates the root node from {@code array} and recursively builds the
	 * tree below it. Does nothing when a root already exists.
	 *
	 * @param array rows used to build the tree
	 */
	public void createDecisionTree(Object[] array) {
		if (root != null) {
			// Check first so that getMaxGain does not flag a column for a
			// tree that is never built.
			return;
		}
		Object[] maxgain = getMaxGain(array);
		int index = ((Integer) maxgain[1]).intValue();

		root = new DecisionTreeNode();
		root.parentNode = null;
		root.parentArrtibute = null;
		if (index == NOT_FOUND) {
			// No column separates the classes: the tree degenerates to a
			// single leaf instead of failing on getArrtibutesArray(-1).
			root.arrtibutesArray = new String[0];
			root.nodeName = getLeafNodeName(array, this.nodeIndex);
			root.childNodesArray = new DecisionTreeNode[0];
			return;
		}
		root.arrtibutesArray = getArrtibutesArray(index);
		root.nodeName = getColumnHeaderNameByIndex(index);
		root.childNodesArray = new DecisionTreeNode[root.arrtibutesArray.length];
		insertDecisionTree(array, root);
	}

	/**
	 * Builds one child of {@code parentNode} per value of the parent's
	 * attribute: an inner node when some unused column still has positive
	 * gain ratio on the matching rows, a leaf otherwise.
	 *
	 * @param array      rows that reached {@code parentNode}
	 * @param parentNode node whose children are created
	 */
	public void insertDecisionTree(Object[] array, DecisionTreeNode parentNode) {
		String[] arrtibutes = parentNode.arrtibutesArray;
		int parentColumn = getColumnHeaderIndexByName(parentNode.nodeName);
		for (int i = 0; i < arrtibutes.length; i++) {
			Object[] pickArray = pickUpAndCreateSubArray(array, arrtibutes[i],
					parentColumn);
			Object[] info = getMaxGain(pickArray);
			int index = ((Integer) info[1]).intValue();
			if (index != NOT_FOUND) {
				DecisionTreeNode currentNode = new DecisionTreeNode();
				currentNode.parentNode = parentNode;
				currentNode.parentArrtibute = arrtibutes[i];
				currentNode.arrtibutesArray = getArrtibutesArray(index);
				currentNode.nodeName = getColumnHeaderNameByIndex(index);
				currentNode.childNodesArray = new DecisionTreeNode[currentNode.arrtibutesArray.length];
				parentNode.childNodesArray[i] = currentNode;
				insertDecisionTree(pickArray, currentNode);
				// The column is only "used" on the path through currentNode.
				// Release it so sibling branches can split on it as well.
				visable[index] = false;
			} else {
				DecisionTreeNode leafNode = new DecisionTreeNode();
				leafNode.parentNode = parentNode;
				leafNode.parentArrtibute = arrtibutes[i];
				leafNode.arrtibutesArray = new String[0];
				// Attribute values come from the whole training set, so this
				// partition may be empty; fall back to the parent's rows so
				// the leaf still gets a label.
				Object[] labelRows = pickArray.length > 0 ? pickArray : array;
				leafNode.nodeName = getLeafNodeName(labelRows, this.nodeIndex);
				leafNode.childNodesArray = new DecisionTreeNode[0];
				parentNode.childNodesArray[i] = leafNode;
			}
		}
	}

	/**
	 * Prints the subtree rooted at {@code node} depth-first: the node name,
	 * then for every child the attribute value leading to it followed by the
	 * child's subtree.
	 *
	 * @param node subtree root
	 */
	public void printDecisionTree(DecisionTreeNode node) {
		System.out.println(node.nodeName);
		DecisionTreeNode[] childs = node.childNodesArray;
		for (int i = 0; i < childs.length; i++) {
			if (childs[i] != null) {
				System.out.println(childs[i].parentArrtibute);
				printDecisionTree(childs[i]);
			}
		}
	}

	/**
	 * Records the target column and the column headers, and resets the
	 * "used column" flags so that only the target column is excluded.
	 *
	 * @param dataArray header row followed by the data rows
	 * @param index     index of the target (class label) column
	 */
	public void init(Object[] dataArray, int index) {
		this.nodeIndex = index;
		this.columnHeaderArray = (String[]) dataArray[0];
		visable = new boolean[this.columnHeaderArray.length];
		for (int i = 0; i < visable.length; i++) {
			visable[i] = (i == index);
		}
	}

	/**
	 * Selects the rows whose value in column {@code index} equals
	 * {@code arrtibute}.
	 *
	 * @param array     rows to filter, each a {@code String[]}
	 * @param arrtibute value to match
	 * @param index     column to test
	 * @return the matching rows in their original order (possibly empty)
	 */
	public Object[] pickUpAndCreateSubArray(Object[] array, String arrtibute,
			int index) {
		List<String[]> list = new ArrayList<String[]>();
		for (int i = 0; i < array.length; i++) {
			String[] strs = (String[]) array[i];
			if (strs[index].equals(arrtibute)) {
				list.add(strs);
			}
		}
		return list.toArray();
	}

	/**
	 * Information gain of splitting {@code array} on column {@code index}:
	 * {@code Info(D) - Info_A(D)}.
	 *
	 * @param array     rows to evaluate
	 * @param index     candidate split column
	 * @param nodeIndex target (class label) column
	 * @return the information gain
	 */
	public double gain(Object[] array, int index, int nodeIndex) {
		int[] counts = separateToSameValueArrays(array, nodeIndex);
		String[] arrtibutes = getArrtibutesArray(index);
		double infoD = infoD(array, counts);
		double infoaD = infoaD(array, index, nodeIndex, arrtibutes);
		return infoD - infoaD;
	}

	/**
	 * Counts how many rows of {@code array} carry each distinct value of
	 * column {@code nodeIndex}. The distinct values, and therefore the order
	 * of the counts, come from {@link #getArrtibutesArray(int)}.
	 *
	 * @param array     rows to count
	 * @param nodeIndex column whose values are counted
	 * @return one count per distinct value of the column
	 */
	public int[] separateToSameValueArrays(Object[] array, int nodeIndex) {
		String[] arrti = getArrtibutesArray(nodeIndex);
		int[] counts = new int[arrti.length];
		for (int i = 0; i < array.length; i++) {
			String[] strs = (String[]) array[i];
			for (int j = 0; j < arrti.length; j++) {
				if (strs[nodeIndex].equals(arrti[j])) {
					counts[j]++;
					// Values are distinct, so no other slot can match.
					break;
				}
			}
		}
		return counts;
	}

	/**
	 * gainRatio = gain(A) / splitInfo(A).
	 *
	 * @param array     rows to evaluate
	 * @param index     candidate split column
	 * @param nodeIndex target (class label) column
	 * @return the gain ratio, or 0 when the split information is 0
	 */
	public double gainRatio(Object[] array, int index, int nodeIndex) {
		double gain = gain(array, index, nodeIndex);
		int[] counts = separateToSameValueArrays(array, index);
		double splitInfo = splitInfoaD(array, counts);
		if (splitInfo != 0) {
			return gain / splitInfo;
		}
		// A column with a single value on these rows cannot split them.
		return 0;
	}

	/**
	 * Entropy of a partition: the sum of {@code -p_i * log2(p_i)} with
	 * {@code p_i = counts[i] / array.length}.
	 *
	 * @param array  rows the counts were taken from (only its length is used)
	 * @param counts number of rows per distinct value
	 * @return the entropy
	 */
	public double infoD(Object[] array, int[] counts) {
		double infoD = 0;
		for (int i = 0; i < counts.length; i++) {
			infoD += DecisionTreeUtil.info(counts[i], array.length);
		}
		return infoD;
	}

	/**
	 * Split information of an attribute: the sum of
	 * {@code -(|Dj|/|D|) * log2(|Dj|/|D|)}. This is the same formula as
	 * {@link #infoD(Object[], int[])}, applied to the attribute's own value
	 * counts.
	 *
	 * @param array  rows the counts were taken from
	 * @param counts number of rows per value of the split attribute
	 * @return the split information
	 */
	public double splitInfoaD(Object[] array, int[] counts) {
		return infoD(array, counts);
	}

	/**
	 * Expected entropy after splitting on column {@code index}: the sum over
	 * the attribute values of {@code (|Dj|/|D|) * Info(Dj)}.
	 *
	 * @param array      rows to evaluate
	 * @param index      split column
	 * @param nodeIndex  target (class label) column
	 * @param arrtibutes distinct values of the split column
	 * @return the weighted entropy
	 */
	public double infoaD(Object[] array, int index, int nodeIndex,
			String[] arrtibutes) {
		double sv_total = 0;
		for (int i = 0; i < arrtibutes.length; i++) {
			sv_total += infoDj(array, index, nodeIndex, arrtibutes[i],
					array.length);
		}
		return sv_total;
	}

	/**
	 * Weighted entropy contribution {@code (|Dj|/|D|) * Info(Dj)} of the rows
	 * whose value in column {@code index} equals {@code arrtibute}.
	 *
	 * @param array     rows to evaluate
	 * @param index     split column
	 * @param nodeIndex target (class label) column
	 * @param arrtibute value of the split column selecting {@code Dj}
	 * @param allTotal  {@code |D|}, the total number of rows
	 * @return the weighted entropy of the partition
	 */
	public double infoDj(Object[] array, int index, int nodeIndex,
			String arrtibute, int allTotal) {
		String[] arrtibutes = getArrtibutesArray(nodeIndex);
		int[] counts = new int[arrtibutes.length];
		int total = 0;

		for (int i = 0; i < array.length; i++) {
			String[] strs = (String[]) array[i];
			if (strs[index].equals(arrtibute)) {
				for (int k = 0; k < arrtibutes.length; k++) {
					if (strs[nodeIndex].equals(arrtibutes[k])) {
						counts[k]++;
						total++;
						// Values are distinct, so no other slot can match.
						break;
					}
				}
			}
		}

		double infoDj = 0;
		for (int i = 0; i < counts.length; i++) {
			infoDj += DecisionTreeUtil.info(counts[i], total);
		}
		return DecisionTreeUtil.getPi(total, allTotal) * infoDj;
	}

	/**
	 * Collects the distinct values of column {@code index} over the whole
	 * training set, ordered by {@code SequenceComparator}.
	 *
	 * @param index column to inspect
	 * @return the distinct values of the column
	 */
	@SuppressWarnings("unchecked") // SequenceComparator implements the raw Comparator type
	public String[] getArrtibutesArray(int index) {
		TreeSet<String> set = new TreeSet<String>(new SequenceComparator());
		for (int i = 0; i < trainingArray.length; i++) {
			String[] strs = (String[]) trainingArray[i];
			set.add(strs[index]);
		}
		return set.toArray(new String[set.size()]);
	}

	/**
	 * Looks up a column name.
	 *
	 * @param index column index
	 * @return the column name, or {@code null} when {@code index} is out of
	 *         range
	 */
	public String getColumnHeaderNameByIndex(int index) {
		if (index >= 0 && index < columnHeaderArray.length) {
			return columnHeaderArray[index];
		}
		return null;
	}

	/**
	 * Determines the label of a leaf: the most frequent value of column
	 * {@code nodeIndex} among {@code array}. When several values are equally
	 * frequent, the one that occurs first in {@code array} wins.
	 *
	 * @param array     rows that reached the leaf, each a {@code String[]}
	 * @param nodeIndex target (class label) column
	 * @return the majority label, or {@code null} when {@code array} is
	 *         {@code null} or empty
	 */
	public String getLeafNodeName(Object[] array, int nodeIndex) {
		if (array == null || array.length == 0) {
			return null;
		}
		// Insertion-ordered so that ties resolve to the first label seen.
		Map<String, Integer> tally = new LinkedHashMap<String, Integer>();
		for (int i = 0; i < array.length; i++) {
			String label = ((String[]) array[i])[nodeIndex];
			Integer seen = tally.get(label);
			int updated = (seen == null) ? 1 : seen.intValue() + 1;
			tally.put(label, Integer.valueOf(updated));
		}
		String best = null;
		int bestCount = 0;
		for (Map.Entry<String, Integer> entry : tally.entrySet()) {
			int count = entry.getValue().intValue();
			if (count > bestCount) {
				best = entry.getKey();
				bestCount = count;
			}
		}
		return best;
	}

	/**
	 * Looks up a column index by header name.
	 *
	 * @param name column name; may be {@code null}
	 * @return the column index, or -1 when {@code name} is {@code null} or
	 *         is not a column header
	 */
	public int getColumnHeaderIndexByName(String name) {
		if (name == null) {
			// Leaves built from an empty training set carry a null name.
			return NOT_FOUND;
		}
		for (int i = 0; i < columnHeaderArray.length; i++) {
			if (name.equals(columnHeaderArray[i])) {
				return i;
			}
		}
		return NOT_FOUND;
	}
}

Java代碼

  1. package org.struct.decisiontree;  
  2. /**
  3. * @author Leon.Chen
  4. */
  5. public class DecisionTreeNode {  
  6.     DecisionTreeNode parentNode;  
  7.     String parentArrtibute;  
  8.     String nodeName;  
  9.     String[] arrtibutesArray;  
  10.     DecisionTreeNode[] childNodesArray;  
package org.struct.decisiontree;

/**
 * One node of the decision tree. Inner nodes are named after the column
 * they split on; leaves are named after the predicted class label and have
 * empty attribute and child arrays.
 *
 * @author Leon.Chen
 */
public class DecisionTreeNode {

	/** Column name for an inner node, predicted label for a leaf. */
	String nodeName;

	/** Distinct values of this node's column; one child per value. */
	String[] arrtibutesArray;

	/** Children, stored at the same index as their value in {@link #arrtibutesArray}. */
	DecisionTreeNode[] childNodesArray;

	/** Parent node, or {@code null} for the root. */
	DecisionTreeNode parentNode;

	/** Value of the parent's column that leads to this node; {@code null} for the root. */
	String parentArrtibute;

}

Java代碼

  1. package org.struct.decisiontree;  
  2. /**
  3. * @author Leon.Chen
  4. */
  5. public class DecisionTreeUtil {  
  6. /**
  7.      * entropy:Info(T)=(i=1...k)pi*log(2)pi
  8.      * 
  9.      * @param x
  10.      * @param total
  11.      * @return double
  12.      */
  13. public static double info(int x, int total) {  
  14. if (x == 0) {  
  15. return 0;  
  16.         }  
  17. double x_pi = getPi(x, total);  
  18. return -(x_pi * logYBase2(x_pi));  
  19.     }  
  20. /**
  21.      * log2y
  22.      * 
  23.      * @param y
  24.      * @return double
  25.      */
  26. public static double logYBase2(double y) {  
  27. return Math.log(y) / Math.log(2);  
  28.     }  
  29. /**
  30.      * pi=|C(i,d)|/|D|
  31.      * 
  32.      * @param x
  33.      * @param total
  34.      * @return double
  35.      */
  36. public static double getPi(int x, int total) {  
  37. return x / (double) total;  
  38.     }  
package org.struct.decisiontree;

/**
 * Static arithmetic helpers for the entropy calculations of the decision
 * tree.
 *
 * @author Leon.Chen
 */
public class DecisionTreeUtil {

	/**
	 * Single entropy term {@code -p * log2(p)} with {@code p = x / total}.
	 *
	 * @param x     number of occurrences of one value
	 * @param total number of observations
	 * @return the entropy term; 0 when {@code x} is 0, which avoids
	 *         evaluating {@code log2(0)}
	 */
	public static double info(int x, int total) {
		if (x != 0) {
			final double probability = getPi(x, total);
			return -(probability * logYBase2(probability));
		}
		return 0;
	}

	/**
	 * Base-2 logarithm, computed by change of base from the natural
	 * logarithm.
	 *
	 * @param y value to take the logarithm of
	 * @return {@code log2(y)}
	 */
	public static double logYBase2(double y) {
		final double naturalLog = Math.log(y);
		return naturalLog / Math.log(2);
	}

	/**
	 * Relative frequency {@code x / total}, computed in floating point.
	 *
	 * @param x     number of occurrences
	 * @param total number of observations
	 * @return the ratio as a {@code double}
	 */
	public static double getPi(int x, int total) {
		final double denominator = total;
		return x / denominator;
	}

}

Java代碼

  1. package org.struct.decisiontree;  
  2. import java.util.Comparator;  
  3. /**
  4. * @author Leon.Chen
  5. */
  6. @SuppressWarnings("unchecked")  
  7. public class SequenceComparator implements Comparator {  
  8. public int compare(Object o1, Object o2) throws ClassCastException {  
  9.         String str1 = (String) o1;  
  10.         String str2 = (String) o2;  
  11. return str1.compareTo(str2);  
  12.     }  
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章