代碼
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn import tree
from sklearn.externals.six import StringIO
# pip install pydotplus
# pip install graphviz
import pydotplus
# Graphviz下載地址:http://www.graphviz.org/download/
import os
# Append the Graphviz bin directory to PATH so pydotplus can locate the Graphviz
# executables when rendering. NOTE: this is a machine-specific Windows install
# location — adjust it for your own system.
os.environ["PATH"] = os.pathsep.join([os.environ["PATH"], 'D:/program files (x86)/Graphviz2.38/bin'])
def loadData():
    """Load ``lenses.txt`` and split it into a feature table and a target list.

    Each non-blank line of the file is tab-separated. Fields 0-3 are the
    features ``age``, ``prescript``, ``astigmatic`` and ``tearRate`` (in that
    order) and the last field is the class label. Both results are printed.

    :return: ``(lenses_pd, lenses_targt)`` where ``lenses_pd`` is a
        ``pandas.DataFrame`` whose columns are the four feature names above,
        in that order, and ``lenses_targt`` is a list of one-element lists
        ``[[label], ...]`` with one entry per data row.
    """
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    with open('lenses.txt') as fr:
        # Strip surrounding whitespace and split on tabs. Blank lines (e.g. an
        # extra empty line at the end of the file) are skipped: they would
        # otherwise produce a bogus '' label and an IndexError below.
        lenses = [line.strip().split('\t') for line in fr if line.strip()]
    # The class label is the last field; kept as one-element lists so the
    # returned shape is unchanged.
    lenses_targt = [[row[-1]] for row in lenses]
    # One list per feature column; field i of each row belongs to feature i.
    # Indexing still raises IndexError on rows with too few fields.
    lenses_dict = {label: [row[i] for row in lenses]
                   for i, label in enumerate(lensesLabels)}
    # Passing ``columns`` pins the feature order on every pandas version
    # (older pandas sorted dict keys alphabetically, changing the order).
    lenses_pd = pd.DataFrame(lenses_dict, columns=lensesLabels)
    print(lenses_targt)
    print(lenses_pd)
    return lenses_pd, lenses_targt
def dataEncoder(data_pd):
    """Label-encode every column of ``data_pd`` in place and print the result.

    Each column is replaced by the integer codes produced by its own
    ``LabelEncoder`` (codes follow the sorted order of the column's distinct
    values).

    :param data_pd: ``pandas.DataFrame`` of categorical values; modified in place.
    :return: dict mapping each column name to the ``LabelEncoder`` fitted on
        that column, so new samples can be encoded with ``transform`` and codes
        decoded with ``inverse_transform``. Callers that ignore the return
        value are unaffected.
    """
    encoders = {}
    for col in data_pd.columns:
        # A fresh encoder per column: reusing a single encoder refits it on
        # every column and discards the mapping of all but the last column.
        le = LabelEncoder()
        # fit_transform = fit (learn the value -> integer code mapping)
        # followed by transform (apply that mapping to the column).
        data_pd[col] = le.fit_transform(data_pd[col])
        encoders[col] = le
    # Print the label-encoded (integer-coded, not normalized) table.
    print(data_pd)
    return encoders
def createTree(data_pd, labels):
    """Fit and return an entropy-criterion decision tree of depth at most 4.

    :param data_pd: DataFrame of (already encoded) feature values, one row per sample.
    :param labels: target values, one entry per row of ``data_pd``.
    :return: the fitted ``tree.DecisionTreeClassifier``.
    """
    # Plain nested lists (no column names), matching the plain-list input
    # that is later passed to predict().
    samples = data_pd.values.tolist()
    classifier = tree.DecisionTreeClassifier(criterion='entropy', max_depth=4)
    classifier.fit(samples, labels)
    return classifier
def exportTree(clf, feature_names):
    """Save the fitted tree to ``lenses.dot`` and render a labelled ``tree.pdf``.

    :param clf: fitted ``DecisionTreeClassifier``.
    :param feature_names: feature column names, in the order used for fitting.
    """
    # Save the raw tree (nodes labelled X[i], no class names) to lenses.dot.
    with open("lenses.dot", 'w') as f:
        tree.export_graphviz(clf, out_file=f)
    # With out_file=None export_graphviz returns the DOT source as a string,
    # so no StringIO buffer is needed (sklearn.externals.six, which previously
    # supplied StringIO, no longer exists in current scikit-learn).
    dot_data = tree.export_graphviz(clf, out_file=None,
                                    feature_names=list(feature_names),
                                    # Node labels are built by string
                                    # concatenation, so class names must be str.
                                    class_names=[str(c) for c in clf.classes_],
                                    filled=True, rounded=True,
                                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    # PDF rendering is delegated to the Graphviz executables found on PATH.
    graph.write_pdf("tree.pdf")
def main():
    """Load the lenses data, encode it, fit and export a decision tree, then predict one sample."""
    # Build the feature table and the target list.
    data_pd, targts = loadData()
    # Label-encode the feature columns in place.
    dataEncoder(data_pd)
    # Fit the tree. Named `clf` rather than `tree` so the local variable does
    # not shadow the sklearn `tree` module imported at the top of the file.
    clf = createTree(data_pd, targts)
    # Save the .dot file and render the PDF.
    exportTree(clf, list(data_pd.columns))
    # Predict one sample: each value is an integer code for the corresponding
    # column of data_pd, in column order.
    # NOTE(review): what [1, 1, 1, 0] means depends on the column order and the
    # encoders' mappings — confirm against the printed encoded table.
    print(clf.predict([[1, 1, 1, 0]]))


if __name__ == '__main__':
    main()
運行結果
[['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['no lenses'], ['no lenses'], ['no lenses'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['no lenses']]
age astigmatic prescript tearRate
0 young no myope reduced
1 young no myope normal
2 young yes myope reduced
3 young yes myope normal
4 young no hyper reduced
5 young no hyper normal
6 young yes hyper reduced
7 young yes hyper normal
8 pre no myope reduced
9 pre no myope normal
10 pre yes myope reduced
11 pre yes myope normal
12 pre no hyper reduced
13 pre no hyper normal
14 pre yes hyper reduced
15 pre yes hyper normal
16 presbyopic no myope reduced
17 presbyopic no myope normal
18 presbyopic yes myope reduced
19 presbyopic yes myope normal
20 presbyopic no hyper reduced
21 presbyopic no hyper normal
22 presbyopic yes hyper reduced
23 presbyopic yes hyper normal
age astigmatic prescript tearRate
0 2 0 1 1
1 2 0 1 0
2 2 1 1 1
3 2 1 1 0
4 2 0 0 1
5 2 0 0 0
6 2 1 0 1
7 2 1 0 0
8 0 0 1 1
9 0 0 1 0
10 0 1 1 1
11 0 1 1 0
12 0 0 0 1
13 0 0 0 0
14 0 1 0 1
15 0 1 0 0
16 1 0 1 1
17 1 0 1 0
18 1 1 1 1
19 1 1 1 0
20 1 0 0 1
21 1 0 0 0
22 1 1 0 1
23 1 1 0 0
['hard']
Process finished with exit code 0
lenses.dot
digraph Tree {
node [shape=box] ;
0 [label="X[3] <= 0.5\nentropy = 1.326\nsamples = 24\nvalue = [4, 15, 5]"] ;
1 [label="X[1] <= 0.5\nentropy = 1.555\nsamples = 12\nvalue = [4, 3, 5]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[2] <= 0.5\nentropy = 0.65\nsamples = 6\nvalue = [0, 1, 5]"] ;
1 -> 2 ;
3 [label="entropy = 0.0\nsamples = 3\nvalue = [0, 0, 3]"] ;
2 -> 3 ;
4 [label="X[0] <= 0.5\nentropy = 0.918\nsamples = 3\nvalue = [0, 1, 2]"] ;
2 -> 4 ;
5 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]"] ;
4 -> 5 ;
6 [label="entropy = 1.0\nsamples = 2\nvalue = [0, 1, 1]"] ;
4 -> 6 ;
7 [label="X[2] <= 0.5\nentropy = 0.918\nsamples = 6\nvalue = [4, 2, 0]"] ;
1 -> 7 ;
8 [label="X[0] <= 1.5\nentropy = 0.918\nsamples = 3\nvalue = [1, 2, 0]"] ;
7 -> 8 ;
9 [label="entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]"] ;
8 -> 9 ;
10 [label="entropy = 0.0\nsamples = 1\nvalue = [1, 0, 0]"] ;
8 -> 10 ;
11 [label="entropy = 0.0\nsamples = 3\nvalue = [3, 0, 0]"] ;
7 -> 11 ;
12 [label="entropy = 0.0\nsamples = 12\nvalue = [0, 12, 0]"] ;
0 -> 12 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
}
樹圖如下