CART 算法手寫數字識別

from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns # 數據可視化的包

加載數據

digits = load_digits()
data = digits.data

查看數據集大小

data.shape

數據集介紹

1797個樣本,每個樣本包括88像素的圖像和一個[0, 9]整數的標籤。

array矩陣類型數據,保存8
8的圖像,裏面的元素是float64類型,共有1797張圖片
用於顯示圖片。

獲取第一張圖片的像素數

print(digits.images[0])

將25%的數據作爲測試集,其餘作爲訓練集

train_x, test_x, train_y, test_y = train_test_split(data, digits.target, test_size=0.25, random_state=33)

採用Z-Score規範化

ss = preprocessing.StandardScaler()
train_ss_x = ss.fit_transform(train_x)
test_ss_x = ss.transform(test_x)

CART 算法簡單介紹

Classification And Regression Tree,即分類迴歸樹算法,簡稱CART算法,它是決策樹的一種實現,通常決策樹主要有三種實現,分別是ID3算法,CART算法和C4.5算法。

CART 算法採用 Gini係數作爲標準進行特徵分割。

決策樹的算法原理大家不理解屬於正常,老師還沒有講到。

有興趣瞭解的同學可以看一下鏈接:

https://zhuanlan.zhihu.com/p/30059442

https://zhuanlan.zhihu.com/p/104462031

#訓練一個DecisionTree分類器

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0,splitter='best',criterion='gini') # sklearn默認使用基尼Gini係數
clf.fit(train_ss_x,train_y)

predict_y = clf.predict(test_ss_x)
print('CART算法準確率: %0.4lf' % accuracy_score(test_y, predict_y))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章