Pattern Recognition and Machine Learning Homework: Decision Trees


Homework 4

Report:

ID3

[Figure: the data table referred to in (a)]

(a) (20 points) Build a decision tree based on this table using the ID3 algorithm (please use the entropy impurity).
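
For reference, the root split can be worked by hand from the 14 samples (9 Yes, 5 No). The derivation below is only a sketch using the standard entropy and information-gain definitions. The symbols D (the full data set) and D_v (the subset with Outlook = v) are notation assumed here, not taken from the assignment.

```latex
\begin{aligned}
H(D) &= -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940 \\
H(D_{\text{Sunny}}) &= -\tfrac{2}{5}\log_2\tfrac{2}{5} - \tfrac{3}{5}\log_2\tfrac{3}{5} \approx 0.971 \\
H(D_{\text{Overcast}}) &= -\tfrac{4}{4}\log_2\tfrac{4}{4} = 0 \\
H(D_{\text{Rain}}) &= -\tfrac{3}{5}\log_2\tfrac{3}{5} - \tfrac{2}{5}\log_2\tfrac{2}{5} \approx 0.971 \\
H(D \mid \text{Outlook}) &= \tfrac{5}{14}(0.971) + \tfrac{4}{14}(0) + \tfrac{5}{14}(0.971) \approx 0.694 \\
\operatorname{Gain}(D, \text{Outlook}) &= H(D) - H(D \mid \text{Outlook}) \approx 0.940 - 0.694 = 0.247
\end{aligned}
```

Repeating the same steps for the other features gives roughly 0.152 for Humidity, 0.048 for Wind and 0.029 for Temperature, so Outlook has the largest gain and is chosen at the root.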


The result (based on ID3):

[Figure: the resulting decision tree]

CART

(b) Build a decision tree based on this table using the CART algorithm (please use the Gini impurity).
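
As a worked check of the Gini impurity at the root, here is a short derivation under the standard definition Gini = 1 - sum of squared class proportions. It uses the multiway split over all three Outlook values, which is what the code in this post computes. The notation D and D_v is assumed here for illustration.

```latex
\begin{aligned}
\operatorname{Gini}(D) &= 1 - \left(\tfrac{9}{14}\right)^2 - \left(\tfrac{5}{14}\right)^2 = \tfrac{90}{196} \approx 0.459 \\
\operatorname{Gini}(D_{\text{Sunny}}) &= 1 - \left(\tfrac{2}{5}\right)^2 - \left(\tfrac{3}{5}\right)^2 = 0.48 \\
\operatorname{Gini}(D_{\text{Overcast}}) &= 1 - \left(\tfrac{4}{4}\right)^2 = 0 \\
\operatorname{Gini}(D_{\text{Rain}}) &= 1 - \left(\tfrac{3}{5}\right)^2 - \left(\tfrac{2}{5}\right)^2 = 0.48 \\
\operatorname{Gini}(D, \text{Outlook}) &= \tfrac{5}{14}(0.48) + \tfrac{4}{14}(0) + \tfrac{5}{14}(0.48) \approx 0.343
\end{aligned}
```

Since \(\sum_k p_k(1 - p_k) = 1 - \sum_k p_k^2\), this is the same quantity as the `p * (1 - p)` sum used in the `gini` function in the Code section.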

The result (based on CART):

[Figure: the resulting decision tree]

(c) Compare the results of (a) and (b), and explain the major difference between ID3 and CART.

ID3
Generating the decision tree: At each level, ID3 picks the feature with the largest information gain and splits the data set on all of that feature's possible values, so the tree generated by ID3 is not necessarily binary.
Pruning: A branch is pruned by comparing the value of the loss function before and after the branch is removed.

CART
Generating the decision tree: When CART grows the tree, it iterates through all possible values of each feature, computes the Gini index (classification) or mean squared error (regression) of each candidate split, and picks the feature value that minimizes it. The data set is then divided according to whether a sample's feature is equal to this value or not, so the decision tree produced by CART is a binary tree. A regression tree partitions the data in the same way as a classification tree.
Pruning: CART uses a non-fixed "regularization parameter". It gradually increases (or decreases) the parameter to obtain a sequence of pruned subtrees, and selects the optimal subtree through cross-validation.
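
The binary "equal to this value or not" test described for CART can be sketched as follows. This is a minimal illustration in plain Python, not the code used in this post: the names `FEATURES`, `ROWS` and `best_binary_split` are hypothetical, and the rows are copied from the PlayTennis data shown in the Code section.

```python
from collections import Counter

# PlayTennis data: (Outlook, Temperature, Humidity, Wind, PlayTennis)
FEATURES = ["Outlook", "Temperature", "Humidity", "Wind"]
ROWS = [
    ("Sunny", "Hot", "High", "Weak", "No"),
    ("Sunny", "Hot", "High", "Strong", "No"),
    ("Overcast", "Hot", "High", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Strong", "No"),
    ("Overcast", "Cool", "Normal", "Strong", "Yes"),
    ("Sunny", "Mild", "High", "Weak", "No"),
    ("Sunny", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "Normal", "Weak", "Yes"),
    ("Sunny", "Mild", "Normal", "Strong", "Yes"),
    ("Overcast", "Mild", "High", "Strong", "Yes"),
    ("Overcast", "Hot", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Strong", "No"),
]


def gini(labels):
    """Gini impurity 1 - sum(p_k^2) of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_binary_split(rows):
    """Return (weighted_gini, feature, value) of the best 'feature == value' test."""
    best = None
    for j, feature in enumerate(FEATURES):
        for value in sorted({r[j] for r in rows}):
            left = [r[-1] for r in rows if r[j] == value]   # feature == value
            right = [r[-1] for r in rows if r[j] != value]  # feature != value
            if not left or not right:
                continue  # a split that leaves one side empty separates nothing
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(rows)
            if best is None or score < best[0]:
                best = (score, feature, value)
    return best


score, feature, value = best_binary_split(ROWS)
print(f"best split: {feature} == {value}, weighted Gini = {score:.4f}")
```

On this data the sketch selects `Outlook == Overcast` with a weighted Gini of 0.3571. The three-way Outlook split computed by the code in this post scores 0.3429, but that is a multiway split rather than the binary test CART uses.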

Code:

ID3

import numpy as np
import pandas as pd
from math import log
import matplotlib.pyplot as plt
%matplotlib inline
dataset = pd.read_csv('data.csv')
dataset = dataset.iloc[:,1:]
dataset
Outlook Temperature Humidity Wind PlayTennis
0 Sunny Hot High Weak No
1 Sunny Hot High Strong No
2 Overcast Hot High Weak Yes
3 Rain Mild High Weak Yes
4 Rain Cool Normal Weak Yes
5 Rain Cool Normal Strong No
6 Overcast Cool Normal Strong Yes
7 Sunny Mild High Weak No
8 Sunny Cool Normal Weak Yes
9 Rain Mild Normal Weak Yes
10 Sunny Mild Normal Strong Yes
11 Overcast Mild High Strong Yes
12 Overcast Hot Normal Weak Yes
13 Rain Mild High Strong No
# Compute the entropy (base 2) of a list of class labels
def calc_ent(datasets):
    data_length = len(datasets)
    label_count = {}
    for i in range(data_length):
        label = datasets[i]  # each element is already a class label
        if label not in label_count:
            label_count[label] = 0
        label_count[label] += 1
    ent = -sum([(p / data_length) * log(p / data_length, 2)
                for p in label_count.values()])
    return ent
calc_ent(dataset['PlayTennis'].tolist())
0.9402859586706309
# Split the data into one sub-DataFrame per distinct value of a feature
def split_dataframe(data, col):
    unique_values = data[col].unique()
    result_dict = {}
    for key in unique_values:
        result_dict[key] = data[data[col] == key]
    return result_dict
data_split = split_dataframe(dataset, 'Outlook')
for item, value in data_split.items():
    print(item, value)
Sunny    Outlook Temperature Humidity    Wind PlayTennis
0    Sunny         Hot     High    Weak         No
1    Sunny         Hot     High  Strong         No
7    Sunny        Mild     High    Weak         No
8    Sunny        Cool   Normal    Weak        Yes
10   Sunny        Mild   Normal  Strong        Yes
Overcast      Outlook Temperature Humidity    Wind PlayTennis
2   Overcast         Hot     High    Weak        Yes
6   Overcast        Cool   Normal  Strong        Yes
11  Overcast        Mild     High  Strong        Yes
12  Overcast         Hot   Normal    Weak        Yes
Rain    Outlook Temperature Humidity    Wind PlayTennis
3     Rain        Mild     High    Weak        Yes
4     Rain        Cool   Normal    Weak        Yes
5     Rain        Cool   Normal  Strong         No
9     Rain        Mild   Normal    Weak        Yes
13    Rain        Mild     High  Strong         No
# Choose the feature with the largest information gain
def choose_best_col(data, label):

    entropy_D = calc_ent(data[label].tolist())  # entropy before the split
    cols = [col for col in data.columns if col not in [label]]

    # Initialize the running best
    max_value, best_col = -999, None
    max_splited = None
    # Try splitting the data on each candidate feature
    for col in cols:
        splited_set = split_dataframe(data, col)
        entropy_DA = 0

        for subset_col, subset in splited_set.items():
            entropy_Di = calc_ent(subset[label].tolist())  # entropy of one subset after the split
            entropy_DA += len(subset) / len(data) * entropy_Di  # accumulate the empirical conditional entropy

        info_gain = entropy_D - entropy_DA  # information gain

        if info_gain > max_value:
            max_value, best_col = info_gain, col
            max_splited = splited_set
    return max_value, best_col, max_splited
choose_best_col(dataset, 'PlayTennis')
(0.2467498197744391,
 'Outlook',
 {'Sunny':    Outlook Temperature Humidity    Wind PlayTennis
  0    Sunny         Hot     High    Weak         No
  1    Sunny         Hot     High  Strong         No
  7    Sunny        Mild     High    Weak         No
  8    Sunny        Cool   Normal    Weak        Yes
  10   Sunny        Mild   Normal  Strong        Yes,
  'Overcast':      Outlook Temperature Humidity    Wind PlayTennis
  2   Overcast         Hot     High    Weak        Yes
  6   Overcast        Cool   Normal  Strong        Yes
  11  Overcast        Mild     High  Strong        Yes
  12  Overcast         Hot   Normal    Weak        Yes,
  'Rain':    Outlook Temperature Humidity    Wind PlayTennis
  3     Rain        Mild     High    Weak        Yes
  4     Rain        Cool   Normal    Weak        Yes
  5     Rain        Cool   Normal  Strong         No
  9     Rain        Mild   Normal    Weak        Yes
  13    Rain        Mild     High  Strong         No})
# ID3 implementation
class ID3Tree:
    # Node class
    class Node:
        def __init__(self, name):
            self.name = name
            self.connections = {}

        def connect(self, label, node):
            self.connections[label] = node
            
    def __init__(self, data, label):
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")
    
    # Print the tree recursively
    def print_tree(self, node, tabs):
        print(tabs + node.name)
        for connection, child_node in node.connections.items():
            print(tabs + "\t" + "(" + connection + ")")
            self.print_tree(child_node, tabs + "\t\t")
    
    def construct_tree(self):
        self.construct(self.root, "", self.data, self.columns)
    
    # Build the tree
    def construct(self, parent_node, parent_connection_label, input_data, columns):
        max_value, best_col, max_splited = choose_best_col(input_data[columns], self.label)
        
        if not best_col:
            node = self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label, node)
            return

        node = self.Node(best_col)
        parent_node.connect(parent_connection_label, node)
        
        new_columns = [col for col in columns if col != best_col]
        
        # Recursively build the subtrees
        for splited_value, splited_data in max_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)
treeId3 = ID3Tree(dataset, 'PlayTennis')
treeId3.construct_tree()
treeId3.print_tree(treeId3.root, "")
Root
	()
		Outlook
			(Sunny)
				Humidity
					(High)
						Temperature
							(Hot)
								Wind
									(Weak)
										No
									(Strong)
										No
							(Mild)
								Wind
									(Weak)
										No
					(Normal)
						Temperature
							(Cool)
								Wind
									(Weak)
										Yes
							(Mild)
								Wind
									(Strong)
										Yes
			(Overcast)
				Temperature
					(Hot)
						Humidity
							(High)
								Wind
									(Weak)
										Yes
							(Normal)
								Wind
									(Weak)
										Yes
					(Cool)
						Humidity
							(Normal)
								Wind
									(Strong)
										Yes
					(Mild)
						Humidity
							(High)
								Wind
									(Strong)
										Yes
			(Rain)
				Wind
					(Weak)
						Temperature
							(Mild)
								Humidity
									(High)
										Yes
									(Normal)
										Yes
							(Cool)
								Humidity
									(Normal)
										Yes
					(Strong)
						Temperature
							(Cool)
								Humidity
									(Normal)
										No
							(Mild)
								Humidity
									(High)
										No
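
The tree printed above keeps splitting even when a subset is already pure (for example, every Overcast sample is Yes), because `construct` only stops once no features are left. Below is a minimal sketch of ID3 with the usual stop-at-pure-node check, written in plain Python with nested dicts. The names `id3`, `entropy`, `FEATURES` and `ROWS` are illustrative assumptions, not the author's code.

```python
from collections import Counter
from math import log2

# PlayTennis data: (Outlook, Temperature, Humidity, Wind, PlayTennis)
FEATURES = ["Outlook", "Temperature", "Humidity", "Wind"]
ROWS = [
    ("Sunny", "Hot", "High", "Weak", "No"),
    ("Sunny", "Hot", "High", "Strong", "No"),
    ("Overcast", "Hot", "High", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Strong", "No"),
    ("Overcast", "Cool", "Normal", "Strong", "Yes"),
    ("Sunny", "Mild", "High", "Weak", "No"),
    ("Sunny", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "Normal", "Weak", "Yes"),
    ("Sunny", "Mild", "Normal", "Strong", "Yes"),
    ("Overcast", "Mild", "High", "Strong", "Yes"),
    ("Overcast", "Hot", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Strong", "No"),
]


def entropy(labels):
    """Shannon entropy (base 2) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def id3(rows, features):
    """Return a class label (leaf) or {feature: {value: subtree}}."""
    labels = [r[-1] for r in rows]
    # Stop when the node is pure, or fall back to a majority vote
    # when no features are left to split on.
    if len(set(labels)) == 1 or not features:
        return Counter(labels).most_common(1)[0][0]

    def gain(feature):
        j = FEATURES.index(feature)
        cond = 0.0
        for value in {r[j] for r in rows}:
            subset = [r[-1] for r in rows if r[j] == value]
            cond += len(subset) / len(rows) * entropy(subset)
        return entropy(labels) - cond

    best = max(features, key=gain)
    j = FEATURES.index(best)
    rest = [f for f in features if f != best]
    # One branch per observed value of the chosen feature (multiway split)
    return {best: {value: id3([r for r in rows if r[j] == value], rest)
                   for value in sorted({r[j] for r in rows})}}


tree = id3(ROWS, FEATURES)
print(tree)
```

With the purity check the tree is much shallower: Outlook at the root, Overcast leading directly to Yes, the Sunny branch split on Humidity (High: No, Normal: Yes) and the Rain branch split on Wind (Weak: Yes, Strong: No).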

CART

import numpy as np
import pandas as pd
from math import log
import matplotlib.pyplot as plt
%matplotlib inline
dataset = pd.read_csv('data.csv')
dataset = dataset.iloc[:, 1:]
dataset
Outlook Temperature Humidity Wind PlayTennis
0 Sunny Hot High Weak No
1 Sunny Hot High Strong No
2 Overcast Hot High Weak Yes
3 Rain Mild High Weak Yes
4 Rain Cool Normal Weak Yes
5 Rain Cool Normal Strong No
6 Overcast Cool Normal Strong Yes
7 Sunny Mild High Weak No
8 Sunny Cool Normal Weak Yes
9 Rain Mild Normal Weak Yes
10 Sunny Mild Normal Strong Yes
11 Overcast Mild High Strong Yes
12 Overcast Hot Normal Weak Yes
13 Rain Mild High Strong No
# Compute the Gini index of a list of class labels
def gini(data):
    probs = [data.count(i) / len(data) for i in set(data)]
    gini = sum([p * (1 - p) for p in probs])
    return gini
gini(dataset['PlayTennis'].tolist())
0.4591836734693877
# Split the data into one sub-DataFrame per distinct value of a feature
def split_dataframe(data, col):
    unique_values = data[col].unique()
    result_dict = {}
    for key in unique_values:
        result_dict[key] = data[data[col] == key]
    return result_dict
split_dataframe(dataset, 'Temperature')
{'Hot':      Outlook Temperature Humidity    Wind PlayTennis
 0      Sunny         Hot     High    Weak         No
 1      Sunny         Hot     High  Strong         No
 2   Overcast         Hot     High    Weak        Yes
 12  Overcast         Hot   Normal    Weak        Yes,
 'Mild':      Outlook Temperature Humidity    Wind PlayTennis
 3       Rain        Mild     High    Weak        Yes
 7      Sunny        Mild     High    Weak         No
 9       Rain        Mild   Normal    Weak        Yes
 10     Sunny        Mild   Normal  Strong        Yes
 11  Overcast        Mild     High  Strong        Yes
 13      Rain        Mild     High  Strong         No,
 'Cool':     Outlook Temperature Humidity    Wind PlayTennis
 4      Rain        Cool   Normal    Weak        Yes
 5      Rain        Cool   Normal  Strong         No
 6  Overcast        Cool   Normal  Strong        Yes
 8     Sunny        Cool   Normal    Weak        Yes}
# Choose the feature with the smallest weighted Gini index
def choose_best_col(data, label):

    gini_D = gini(data[label].tolist())  # Gini index before the split
    cols = [col for col in data.columns if col not in [label]]

    # Initialize the running best
    min_value, best_col = 999, None
    min_splited = None

    # Try splitting the data on each candidate feature
    for col in cols:
        splited_set = split_dataframe(data, col)
        gini_DA = 0
        for subset_col, subset in splited_set.items():

            gini_Di = gini(subset[label].tolist())  # Gini index of one subset after the split
            gini_DA += len(subset) / len(data) * gini_Di  # accumulate the weighted Gini index of this feature

        if gini_DA < min_value:
            min_value, best_col = gini_DA, col
            min_splited = splited_set
    return min_value, best_col, min_splited
choose_best_col(dataset, 'PlayTennis')
(0.34285714285714286,
 'Outlook',
 {'Sunny':    Outlook Temperature Humidity    Wind PlayTennis
  0    Sunny         Hot     High    Weak         No
  1    Sunny         Hot     High  Strong         No
  7    Sunny        Mild     High    Weak         No
  8    Sunny        Cool   Normal    Weak        Yes
  10   Sunny        Mild   Normal  Strong        Yes,
  'Overcast':      Outlook Temperature Humidity    Wind PlayTennis
  2   Overcast         Hot     High    Weak        Yes
  6   Overcast        Cool   Normal  Strong        Yes
  11  Overcast        Mild     High  Strong        Yes
  12  Overcast         Hot   Normal    Weak        Yes,
  'Rain':    Outlook Temperature Humidity    Wind PlayTennis
  3     Rain        Mild     High    Weak        Yes
  4     Rain        Cool   Normal    Weak        Yes
  5     Rain        Cool   Normal  Strong         No
  9     Rain        Mild   Normal    Weak        Yes
  13    Rain        Mild     High  Strong         No})
# CART implementation
class CartTree:
    # Node class
    class Node:
        def __init__(self, name):
            self.name = name
            self.connections = {}

        def connect(self, label, node):
            self.connections[label] = node

    def __init__(self, data, label):
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")

    # Print the tree recursively
    def print_tree(self, node, tabs):
        print(tabs + node.name)
        for connection, child_node in node.connections.items():
            print(tabs + "\t" + "(" + connection + ")")
            self.print_tree(child_node, tabs + "\t\t")

    def construct_tree(self):
        self.construct(self.root, "", self.data, self.columns)

    # Build the tree
    def construct(self, parent_node, parent_connection_label, input_data,
                  columns):
        min_value, best_col, min_splited = choose_best_col(
            input_data[columns], self.label)
        if not best_col:
            node = self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label, node)
            return

        node = self.Node(best_col)
        parent_node.connect(parent_connection_label, node)

        new_columns = [col for col in columns if col != best_col]
        # Recursively build the subtrees
        for splited_value, splited_data in min_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)
treeCart = CartTree(dataset, 'PlayTennis')
treeCart.construct_tree()
treeCart.print_tree(treeCart.root, "")
Root
	()
		Outlook
			(Sunny)
				Humidity
					(High)
						Temperature
							(Hot)
								Wind
									(Weak)
										No
									(Strong)
										No
							(Mild)
								Wind
									(Weak)
										No
					(Normal)
						Temperature
							(Cool)
								Wind
									(Weak)
										Yes
							(Mild)
								Wind
									(Strong)
										Yes
			(Overcast)
				Temperature
					(Hot)
						Humidity
							(High)
								Wind
									(Weak)
										Yes
							(Normal)
								Wind
									(Weak)
										Yes
					(Cool)
						Humidity
							(Normal)
								Wind
									(Strong)
										Yes
					(Mild)
						Humidity
							(High)
								Wind
									(Strong)
										Yes
			(Rain)
				Wind
					(Weak)
						Temperature
							(Mild)
								Humidity
									(High)
										Yes
									(Normal)
										Yes
							(Cool)
								Humidity
									(Normal)
										Yes
					(Strong)
						Temperature
							(Cool)
								Humidity
									(Normal)
										No
							(Mild)
								Humidity
									(High)
										No
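
The `CartTree` class above creates one branch per value of the chosen feature, so its printed tree is identical to the ID3 output. A strictly binary variant, closer to the CART description in the report, might look like the sketch below. It is an illustrative assumption rather than the author's implementation: `build_cart`, `predict` and the `{"test", "yes", "no"}` node layout are hypothetical names, and no pruning is performed.

```python
from collections import Counter

# PlayTennis data: (Outlook, Temperature, Humidity, Wind, PlayTennis)
FEATURES = ["Outlook", "Temperature", "Humidity", "Wind"]
ROWS = [
    ("Sunny", "Hot", "High", "Weak", "No"),
    ("Sunny", "Hot", "High", "Strong", "No"),
    ("Overcast", "Hot", "High", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Strong", "No"),
    ("Overcast", "Cool", "Normal", "Strong", "Yes"),
    ("Sunny", "Mild", "High", "Weak", "No"),
    ("Sunny", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "Normal", "Weak", "Yes"),
    ("Sunny", "Mild", "Normal", "Strong", "Yes"),
    ("Overcast", "Mild", "High", "Strong", "Yes"),
    ("Overcast", "Hot", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Strong", "No"),
]


def gini(labels):
    """Gini impurity 1 - sum(p_k^2) of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def build_cart(rows):
    """Grow an unpruned binary tree using 'feature == value' tests."""
    labels = [r[-1] for r in rows]
    if len(set(labels)) == 1:
        return labels[0]  # pure node becomes a leaf
    best = None
    for j in range(len(FEATURES)):
        for value in sorted({r[j] for r in rows}):
            yes = [r for r in rows if r[j] == value]
            no = [r for r in rows if r[j] != value]
            if not yes or not no:
                continue  # skip tests that do not separate anything
            score = (len(yes) * gini([r[-1] for r in yes])
                     + len(no) * gini([r[-1] for r in no])) / len(rows)
            if best is None or score < best[0]:
                best = (score, j, value, yes, no)
    if best is None:
        # Identical feature vectors with mixed labels: majority vote
        return Counter(labels).most_common(1)[0][0]
    _, j, value, yes, no = best
    return {"test": (FEATURES[j], value),
            "yes": build_cart(yes),
            "no": build_cart(no)}


def predict(tree, row):
    """Follow the yes/no branches until a leaf label is reached."""
    while isinstance(tree, dict):
        feature, value = tree["test"]
        branch = "yes" if row[FEATURES.index(feature)] == value else "no"
        tree = tree[branch]
    return tree


tree = build_cart(ROWS)
print(tree["test"])
print(all(predict(tree, r) == r[-1] for r in ROWS))
```

Here the root test is `Outlook == Overcast` rather than a three-way Outlook split, and every internal node has exactly two children. Because the tree is grown until each leaf is pure, it reproduces all 14 training labels. Pruning with a regularization parameter is left out of this sketch.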
