TensorFlow 2.1 Transfer Learning Based on VGG

Environment

TensorFlow 2.1

Preparation

Download the VGG weights. Keras can download them automatically on first use, or you can fetch them offline in advance.
Download the images to train on. The dataset contains five kinds of flowers ('daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'):

https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
Then extract the archive into the flower_photos directory under your project root.
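
If you prefer to fetch the archive programmatically, a minimal sketch using tf.keras.utils.get_file is shown below. Note that it extracts into Keras's cache directory (~/.keras/datasets) rather than ./flower_photos, so copy the extracted folder or point the training script at it:

import pathlib
import tensorflow as tf

# Download and extract the archive; get_file caches it under ~/.keras/datasets.
data_dir = tf.keras.utils.get_file(
    'flower_photos',
    origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
print('Extracted to:', pathlib.Path(data_dir))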

Brief Explanation

Transfer learning based on VGG: the VGG weights themselves are not trained, because they are already pre-trained.
We only need to remove VGG's fully connected layers and attach our own fully connected layer; then a short training run on just that new layer is enough. The core idea is sketched right after this paragraph, and the full training script follows below.
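
A minimal sketch of that idea (the complete script in the section below uses the same structure):

from tensorflow.keras.applications import VGG16
import tensorflow.keras as keras

base = VGG16(input_shape=(224, 224, 3), include_top=False)  # drop VGG's own FC head
base.trainable = False                                       # keep the pre-trained weights fixed

model = keras.Sequential([
    base,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(5, activation='softmax')  # our new 5-class head
])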

Training and Validation Results

Accuracy on the validation set is around 80%; the tail of one training run is shown below.

87/290 [============================>.] - ETA: 0s - loss: 0.4844 - categorical_accuracy: 0.8358
288/290 [============================>.] - ETA: 0s - loss: 0.4836 - categorical_accuracy: 0.8364
289/290 [============================>.] - ETA: 0s - loss: 0.4831 - categorical_accuracy: 0.8366save_weight 36 0.5442695867802415

290/290 [==============================] - 35s 119ms/step - loss: 0.4835 - categorical_accuracy: 0.8365 - val_loss: 0.5443 - val_categorical_accuracy: 0.8071

Complete Training Code

from tensorflow.keras.applications import VGG16
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import tensorflow.keras.preprocessing.image as image
import os

# Load VGG16 without its fully connected top layers. The ImageNet weights are
# downloaded automatically on first use (or can be placed locally in advance).
vgg16 = VGG16(input_shape=(224, 224, 3), include_top=False)

best_model = vgg16

# Freeze every VGG16 layer so only the new head is trained.
for layer in best_model.layers:
    layer.trainable = False

new_model = keras.Sequential([best_model])

# New classification head: global average pooling followed by a 5-way softmax.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
new_output = keras.layers.Dense(5, activation=tf.nn.softmax,
                                kernel_initializer=tf.initializers.Constant(0.001))
new_model.add(global_average_layer)
new_model.add(new_output)

new_model.compile(optimizer=keras.optimizers.Adam(),
                   loss=keras.losses.categorical_crossentropy,
                   # metrics=['accuracy'])
                   metrics=[keras.metrics.categorical_accuracy])

new_model.summary()


# Class indices to names: daisy, dandelion, roses, sunflowers, tulips
label_names = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
label_key = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

train_datagen = image.ImageDataGenerator(
    rescale=1 / 255,
    rotation_range=40,  # degrees (0-180): range for random rotations
    width_shift_range=0.2,  # fraction of width/height for random shifts
    height_shift_range=0.2,
    shear_range=0.2,  # random shear angle
    zoom_range=0.2,  # random zoom range
    horizontal_flip=True,  # randomly flip half of the images horizontally
    validation_split=0.2,  # reserve 20% of the images for validation
    fill_mode='nearest'  # strategy for filling newly created pixels
)

IMG_SIZE = 224
BATCH_SIZE = 32
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
pic_folder = './flower_photos'

train_generator = train_datagen.flow_from_directory(
    directory=pic_folder,
    target_size=IMG_SHAPE[:-1],
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=10,
    subset='training',
    shuffle=True)

validation_generator = train_datagen.flow_from_directory(
    directory=pic_folder,
    target_size=IMG_SHAPE[:-1],
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=10,
    subset='validation',
    shuffle=True)

current_max_loss = 9999
weight_file = './weightsf/model.h5'

# Make sure the weights directory exists, then resume from earlier weights if present.
os.makedirs(os.path.dirname(weight_file), exist_ok=True)
if os.path.isfile(weight_file):
    print('load weight')
    new_model.load_weights(weight_file)

# Save the weights whenever the validation loss improves on the best value so far.
def save_weight(epoch, logs):
    global current_max_loss
    if logs['val_loss'] is not None and logs['val_loss'] < current_max_loss:
        current_max_loss = logs['val_loss']
        print('save_weight', epoch, current_max_loss)
        new_model.save_weights(weight_file)

batch_print_callback = keras.callbacks.LambdaCallback(
    on_epoch_end=save_weight
)

callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=4, monitor='val_loss'),
    batch_print_callback,
    # keras.callbacks.ModelCheckpoint('./weights/model.h5', save_best_only=True),
    tf.keras.callbacks.TensorBoard(log_dir='logsf')
]

history = new_model.fit_generator(train_generator, steps_per_epoch=290, epochs=40, callbacks=callbacks,
                                   validation_data=validation_generator, validation_steps=70)
print(history)

def show_result(history):
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.plot(history.history['categorical_accuracy'])
    plt.plot(history.history['val_categorical_accuracy'])
    plt.legend(['loss', 'val_loss', 'categorical_accuracy', 'val_categorical_accuracy'],
               loc='upper left')
    plt.show()
    print(history)

show_result(history)
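
After training, you can sanity-check the model by reloading the best saved weights and classifying a single image. This is only a sketch; the sample path below is a placeholder that should point at one of your own flower photos:

import numpy as np

# Reload the best weights and predict the class of one image.
new_model.load_weights(weight_file)
img = image.load_img('./flower_photos/daisy/sample.jpg', target_size=(224, 224))  # placeholder path
x = image.img_to_array(img) / 255.0  # same rescaling as during training
x = np.expand_dims(x, axis=0)
pred = new_model.predict(x)
print('predicted:', label_key[int(np.argmax(pred[0]))])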
