環境
TensorFlow 2.1
準備工作
下載 VGG16 的預訓練權重:可以在第一次運行代碼時自動下載,也可以事先離線下載。
下載要訓練的圖片。這裏的圖片包含五種花:daisy(雛菊)、dandelion(蒲公英)、roses(玫瑰)、sunflowers(向日葵)、tulips(鬱金香)。
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
然後解壓到你的項目目錄(運行代碼的目錄)底下,使圖片位於 flower_photos 目錄中,每種花一個子目錄。
簡要說明
這是基於 VGG16 的遷移學習:VGG16 卷積部分的權重已經預先訓練好了,所以把它們凍結,不再訓練。
但是要去掉 VGG16 原有的全連接層,換成我們自己的分類層(全局平均池化 + 5 類 softmax 全連接層)。我們只需要簡單訓練一下這個新加的全連接層就可以了。
訓練集與驗證集的結果
驗證集上的準確率在 80% 左右(訓練集上約 84%)。
287/290 [============================>.] - ETA: 0s - loss: 0.4844 - categorical_accuracy: 0.8358
288/290 [============================>.] - ETA: 0s - loss: 0.4836 - categorical_accuracy: 0.8364
289/290 [============================>.] - ETA: 0s - loss: 0.4831 - categorical_accuracy: 0.8366save_weight 36 0.5442695867802415
290/290 [==============================] - 35s 119ms/step - loss: 0.4835 - categorical_accuracy: 0.8365 - val_loss: 0.5443 - val_categorical_accuracy: 0.8071
訓練的完整代碼
from tensorflow.keras.applications import VGG16
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import tensorflow.keras.preprocessing.image as image
import os as os
# Convolutional base: VGG16 without its original fully connected classifier
# (include_top=False). `weights` is left at its default (pretrained ImageNet
# weights), which are downloaded on first use unless already cached.
vgg16 = VGG16(input_shape=(224, 224, 3), include_top=False)

# Freeze the whole convolutional base so only the new head below is trained.
# Setting `trainable` on the model freezes every layer inside it. Keras only
# picks up `trainable` changes at compile() time, so do this before compiling.
vgg16.trainable = False

# New classification head: average-pool the final feature maps into one
# vector per image, then a 5-way softmax, one unit per flower class.
new_model = keras.Sequential([
    vgg16,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(
        5,
        activation=tf.nn.softmax,
        kernel_initializer=tf.initializers.Constant(0.001),
    ),
])

new_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.categorical_crossentropy,
    metrics=[keras.metrics.categorical_accuracy],
)
new_model.summary()
#雛菊,蒲公英, 鬱金香
label_names={'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
label_key=['daisy','dandelion','roses','sunflowers','tulips']
train_datagen = image.ImageDataGenerator(
rescale=1 / 255,
rotation_range=40, # 角度值,0-180.表示圖像隨機旋轉的角度範圍
width_shift_range=0.2, # 平移比例,下同
height_shift_range=0.2,
shear_range=0.2, # 隨機錯切變換角度
zoom_range=0.2, # 隨即縮放比例
horizontal_flip=True, # 隨機將一半圖像水平翻轉
validation_split=0.2,
fill_mode='nearest' # 填充新創建像素的方法
)
IMG_SIZE = 224
BATCH_SIZE = 32
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
pic_folder = './flower_photos'
train_generator = train_datagen.flow_from_directory(
directory=pic_folder,
target_size=IMG_SHAPE[:-1],
color_mode='rgb',
classes=None,
class_mode='categorical',
batch_size=10,
subset='training',
shuffle=True)
validation_generator = train_datagen.flow_from_directory(
directory=pic_folder,
target_size=IMG_SHAPE[:-1],
color_mode='rgb',
classes=None,
class_mode='categorical',
batch_size=10,
subset='validation',
shuffle=True)
weight_file = './weightsf/model.h5'
# Saving weights fails if the target directory does not exist, so create it.
os.makedirs(os.path.dirname(weight_file), exist_ok=True)

# Resume from the best head weights saved by a previous run, if present.
if os.path.isfile(weight_file):
    print('load weight')
    new_model.load_weights(weight_file)

callbacks = [
    # Stop once val_loss has not improved for 4 consecutive epochs.
    keras.callbacks.EarlyStopping(patience=4, monitor='val_loss'),
    # Overwrite weight_file whenever val_loss reaches a new minimum.
    keras.callbacks.ModelCheckpoint(
        weight_file,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True,
        verbose=1,
    ),
    keras.callbacks.TensorBoard(log_dir='logsf'),
]

# Model.fit accepts generators directly (fit_generator is deprecated).
# Using len(generator) makes each epoch cover every batch of each split once.
history = new_model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=40,
    callbacks=callbacks,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)
print(history.history)
def show_result(history):
    """Plot the per-epoch training curves recorded in a Keras History.

    Loss and accuracy are drawn on separate axes because they are on
    different scales. A validation curve is drawn only if the history
    recorded it.

    Args:
        history: the object returned by ``Model.fit``; its ``history``
            attribute maps metric names to lists of per-epoch values.
    """
    curves = history.history
    _, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, metric in zip(axes, ('loss', 'categorical_accuracy')):
        for name in (metric, 'val_' + metric):
            if name in curves:
                ax.plot(curves[name], label=name)
        ax.set_title(metric)
        ax.set_xlabel('epoch')
        ax.legend(loc='best')
    plt.show()


show_result(history)