Handwritten Digit Recognition with TensorFlow 2.0 and Softmax

Reference link
The project is written in a Jupyter notebook.

import tensorflow as tf
from tensorflow.keras.datasets import mnist
import numpy as np
from sklearn.preprocessing import StandardScaler

print(tf.__version__)
2.0.0
## Load the data: 60,000 training samples, 10,000 test samples
(x_train_all, y_train_all), (x_test, y_test) = mnist.load_data()  # downloads the data from the official site on first run, which may be slow
print(type(x_train_all))
<class 'numpy.ndarray'>
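
If the download is slow or keeps failing, a pre-downloaded mnist.npz can be dropped into the Keras cache directory (~/.keras/datasets/), and load_data() will read the local copy. A minimal sketch, assuming the file is already in place (the path argument is the cache filename, not something used elsewhere in this post):

# A sketch: load_data caches the dataset under ~/.keras/datasets/<path>.
# If mnist.npz is already there, no network access is needed.
(x_train_all, y_train_all), (x_test, y_test) = mnist.load_data(path='mnist.npz')
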
#print(x_train_all.shape, x_test.shape)   # (60000, 28, 28) (10000, 28, 28)
# Standardize the data (zero mean, unit variance), fitting the scaler on the training set only
scaler = StandardScaler()
scaled_x_train_all = scaler.fit_transform(x_train_all.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
scaled_x_test = scaler.transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
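
Note that the scaler is fitted on the training pixels only and the same transform is reused for the test set, which avoids leaking test statistics. A common, simpler alternative (a sketch, not what this post uses; variable names are illustrative) is to rescale pixel values from [0, 255] to [0, 1]:

# A sketch of the simpler alternative scaling.
x_train_scaled01 = x_train_all.astype(np.float32) / 255.0
x_test_scaled01 = x_test.astype(np.float32) / 255.0
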

# Split the training data into training and validation sets
scaled_x_train,scaled_x_valid = scaled_x_train_all[5000:],scaled_x_train_all[:5000]
y_train,y_valid = y_train_all[5000:],y_train_all[:5000]
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense

# Build the model with the functional API
a = Input(shape=(784,))   # shape of a single flattened sample, excluding the batch dimension
b = Dense(10,activation='softmax')(a)
model = Model(a,b)
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 784)]             0         
_________________________________________________________________
dense (Dense)                (None, 10)                7850      
=================================================================
Total params: 7,850
Trainable params: 7,850
Non-trainable params: 0
_________________________________________________________________
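
The 7,850 parameters come from the single Dense layer: 784 × 10 weights plus 10 biases. For comparison, the same model can be written with the Sequential API; a minimal sketch (this post itself uses the functional API above):

from tensorflow.keras.models import Sequential
# A sketch: one Dense softmax layer on flattened 784-pixel inputs,
# equivalent to the functional model built above.
seq_model = Sequential([
    Dense(10, activation='softmax', input_shape=(784,))
])
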
model.compile(loss="sparse_categorical_crossentropy",optimizer = "sgd",metrics=["accuracy"])
model.fit(scaled_x_train.reshape(-1,784),y_train,epochs=20,validation_data=(scaled_x_valid.reshape(-1,784),y_valid))
Train on 55000 samples, validate on 5000 samples
Epoch 1/20
55000/55000 [==============================] - 3s 63us/sample - loss: 0.4423 - accuracy: 0.8708 - val_loss: 0.3158 - val_accuracy: 0.9088
Epoch 2/20
55000/55000 [==============================] - 3s 46us/sample - loss: 0.3227 - accuracy: 0.9079 - val_loss: 0.2928 - val_accuracy: 0.9194
Epoch 3/20
55000/55000 [==============================] - 2s 45us/sample - loss: 0.3039 - accuracy: 0.9136 - val_loss: 0.2796 - val_accuracy: 0.9230
Epoch 4/20
55000/55000 [==============================] - 3s 47us/sample - loss: 0.2939 - accuracy: 0.9171 - val_loss: 0.2732 - val_accuracy: 0.9254
Epoch 5/20
55000/55000 [==============================] - 3s 46us/sample - loss: 0.2871 - accuracy: 0.9201 - val_loss: 0.2722 - val_accuracy: 0.9252
Epoch 6/20
55000/55000 [==============================] - 3s 49us/sample - loss: 0.2819 - accuracy: 0.9218 - val_loss: 0.2711 - val_accuracy: 0.9256
Epoch 7/20
55000/55000 [==============================] - 3s 49us/sample - loss: 0.2781 - accuracy: 0.9219 - val_loss: 0.2707 - val_accuracy: 0.9276
Epoch 8/20
55000/55000 [==============================] - 3s 62us/sample - loss: 0.2751 - accuracy: 0.9233 - val_loss: 0.2692 - val_accuracy: 0.9262
Epoch 9/20
55000/55000 [==============================] - 3s 54us/sample - loss: 0.2724 - accuracy: 0.9243 - val_loss: 0.2650 - val_accuracy: 0.9310
Epoch 10/20
55000/55000 [==============================] - 3s 63us/sample - loss: 0.2704 - accuracy: 0.9247 - val_loss: 0.2684 - val_accuracy: 0.9272
Epoch 11/20
55000/55000 [==============================] - 3s 57us/sample - loss: 0.2683 - accuracy: 0.9250 - val_loss: 0.2645 - val_accuracy: 0.9276
Epoch 12/20
55000/55000 [==============================] - 3s 57us/sample - loss: 0.2666 - accuracy: 0.9259 - val_loss: 0.2620 - val_accuracy: 0.9296
Epoch 13/20
55000/55000 [==============================] - 3s 52us/sample - loss: 0.2653 - accuracy: 0.9264 - val_loss: 0.2621 - val_accuracy: 0.9308
Epoch 14/20
55000/55000 [==============================] - 3s 53us/sample - loss: 0.2637 - accuracy: 0.9268 - val_loss: 0.2638 - val_accuracy: 0.9286
Epoch 15/20
55000/55000 [==============================] - 2s 44us/sample - loss: 0.2622 - accuracy: 0.9274 - val_loss: 0.2639 - val_accuracy: 0.9294
Epoch 16/20
55000/55000 [==============================] - 4s 64us/sample - loss: 0.2607 - accuracy: 0.9277 - val_loss: 0.2629 - val_accuracy: 0.9292
Epoch 17/20
55000/55000 [==============================] - 3s 50us/sample - loss: 0.2600 - accuracy: 0.9279 - val_loss: 0.2632 - val_accuracy: 0.9298
Epoch 18/20
55000/55000 [==============================] - 2s 42us/sample - loss: 0.2589 - accuracy: 0.9281 - val_loss: 0.2621 - val_accuracy: 0.9306
Epoch 19/20
55000/55000 [==============================] - 2s 41us/sample - loss: 0.2583 - accuracy: 0.9282 - val_loss: 0.2631 - val_accuracy: 0.9282
Epoch 20/20
55000/55000 [==============================] - 2s 44us/sample - loss: 0.2572 - accuracy: 0.9290 - val_loss: 0.2626 - val_accuracy: 0.9310
<tensorflow.python.keras.callbacks.History at 0x258002330c8>
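
Validation loss plateaus around 0.26 after roughly epoch 10, so a fixed 20 epochs is somewhat arbitrary. A minimal sketch of training with early stopping instead (not part of the original run; the monitor and patience values are illustrative):

from tensorflow.keras.callbacks import EarlyStopping

# A sketch: stop when val_loss has not improved for 3 epochs and keep the best weights.
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
history = model.fit(scaled_x_train.reshape(-1, 784), y_train,
                    epochs=20,
                    validation_data=(scaled_x_valid.reshape(-1, 784), y_valid),
                    callbacks=[early_stop])
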
# Evaluate on the test set
model.evaluate(scaled_x_test.reshape(-1,784),y_test,verbose=0)  # verbose=0 suppresses the progress output
[0.2692306630671024, 0.926]
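
evaluate() returns the loss followed by the metrics listed in compile(), so the list above is test loss ≈ 0.269 and test accuracy ≈ 0.926. A sketch of unpacking the result for readability:

# A sketch: unpack loss and accuracy from evaluate().
test_loss, test_acc = model.evaluate(scaled_x_test.reshape(-1, 784), y_test, verbose=0)
print(f"test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}")
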
# Pick a random test image
img_random = scaled_x_test[np.random.randint(0,len(scaled_x_test))]
import matplotlib.pyplot as plt
%matplotlib inline

plt.imshow(img_random)
plt.show()

[Figure: the randomly selected test digit rendered by plt.imshow]

# Model prediction
prob = model.predict(img_random.reshape(-1,784))
print(np.argmax(prob))
1
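
predict() returns one row of 10 class probabilities per input image, and argmax picks the most likely digit. A sketch of predicting the whole test set at once (variable names are illustrative):

# A sketch: batch prediction over the full test set.
probs = model.predict(scaled_x_test.reshape(-1, 784))   # shape (10000, 10)
pred_labels = np.argmax(probs, axis=1)                  # one predicted digit per image
print((pred_labels == y_test).mean())                   # fraction correct; should be close to evaluate()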