Tensorflow查看網絡(inspect)、凍結變量(freeze)和遷移訓練(finetune)

Tensorflow查看網絡、凍結變量和遷移訓練

(Inspect network structure, freeze graph variables, and finetune/transfer learning in Tensorflow)

 

1.    查看網絡結構和參數

python
/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/inspect_checkpoint.py
--file_name=model.ckpt-1562770
--tensor_name=unit_1_2/sub1/conv1/DW

源碼中的inspect_checkpoint.py可以看ckpt文件中的層和某層的權重值

如果只有--file_name就只顯示層,如果還有--tensor_name就能顯示那一層的權重

 

2.    只訓練graph中部分變量(相當於凍結了其他變量)

Tensorflow在構建graph的過程中會默認自動收集一些變量名到對應的Collection。例如TRAINABLE_VARIABLES就是所有可訓練的變量集合。

因此可以通過使用tf.get_collection,指定TRAINABLE_VARIABLES,使其僅包含我們需要重新訓練的變量,來凍結其他變量的訓練。

例子如下:

    first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
        "unit_last")
    trainable_variables = first_train_vars
    #print trainable_variables
    grads = self.optimizer.compute_gradients(self.cost, self.trainable_variables)

 

3.    更改graph後恢復訓練

根據monitored_session.py,使用MonitoredTrainingSession來開啓控制Session的時候,若指定的checkpoint路徑中有上次的存檔,則現有源碼只能嚴格按照之前訓練恢復。因此我們需要一個空的checkpoint路徑,此時MonitoredTrainingSession就會執行init_op以及init_fn。在init_fn中自己添加恢複函數,並把init_fn作爲參數加入MonitoredTrainingSession中的scaffold即可。

例子如下:

    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=['logit'])
    init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
        ckpt.model_checkpoint_path, variables_to_restore)
    def InitAssignFn(scaffold, sess):
        sess.run(init_assign_op, init_feed_dict)
    scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)


 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章