Tensorflow加載預訓練模型的特殊操作 1 只加載部分參數 2 從兩個預訓練模型中加載不同部分參數 3 從參數名稱不一致的模型中加載參數

最近看到一個巨牛的人工智能教程,分享一下給大家。教程不僅是零基礎,通俗易懂,而且非常風趣幽默,像看小說一樣!覺得太牛了,所以分享給大家。平時碎片時間可以當小說看,【點這裏可以去膜拜一下大神的“小說”】

在前面的文章【Tensorflow加載預訓練模型和保存模型】中介紹瞭如何保存訓練好的模型,已經將預訓練好的模型參數加載到當前網絡。這些屬於常規操作,即預訓練的模型與當前網絡結構的命名完全一致。

本文介紹一些不常規的操作:

  1. 如何只加載部分參數?
  2. 如何從兩個模型中加載不同部分參數?
  3. 當預訓練的模型的命名與當前定義的網絡中的參數命名不一致時該怎麼辦?

1 只加載部分參數

舉個例子,對已有的網絡結構做了細微修改,例如只改了幾層卷積通道數。如果從頭訓練顯然沒有finetune收斂速度快,但是模型又沒法全部加載。此時,只需將未修改部分參數加載到當前網絡即可。假設修改過的卷積層名稱包含`conv_``,示例代碼如下:

import tensorflow as tf
def restore(sess, ckpt_path):
    vars = tf.trainable_variables()
    vars = [v for v vars if not "conv_1" in v.name]
    saver = tf.train.Saver(var_list=vars)
    saver.restore(sess, ckpt_path)

2 從兩個預訓練模型中加載不同部分參數

如果需要從兩個不同的預訓練模型中加載不同部分參數,例如,網絡中的前半部分用一個預訓練模型參數,後半部分用另一個預訓練模型中的參數,示例代碼如下:

import tensorflow as tf
def restore(sess, ckpt_path):
    vars = tf.trainable_variables()
    model_1_vars = [v for v vars if "model_1" in v.name]
    model_2_vars = [v for v vars if "model_2" in v.name]
    saver_1 = tf.train.Saver(var_list=model_1_vars)
    saver_2 = tf.train.Saver(var_list=model_2_vars)
    saver_1 .restore(sess, ckpt_path)
    saver_2 .restore(sess, ckpt_path)

3 從參數名稱不一致的模型中加載參數

舉個例子,例如,預訓練的模型所有的參數有個前綴name_1,現在定義的網絡結構中的參數以name_2作爲前綴。那麼使用如下示例代碼即可加載:

import tensorflow as tf
def restore(sess, ckpt_path):
    vars = tf.trainable_variables()
    vars_dict = dict()
    for v in vars:
        key = v.name.split(':')[0]
        if key.startswith("name_2/"):
            key = key.replace("name_2/", "name_1/")
        vars_dict[key] = v
    saver =tf.train.Saver(var_list=vars_dict)
    saver.restore(sess, ckpt_path)

注意: 使用上面代碼時,要確保參數的shape一致,否則會無法加載參數。

如果不知道預訓練的ckpt中參數名稱,可以使用如下代碼打印:

for name, shape in tf.train.list_variables(ckpt_path):
    print(name)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章