Tensorflow查看網絡、凍結變量和遷移訓練
1. 查看網絡結構和參數
python
/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/inspect_checkpoint.py
--file_name=model.ckpt-1562770
--tensor_name=unit_1_2/sub1/conv1/DW
源碼中的inspect_checkpoint.py可以看ckpt文件中的層和某層的權重值
如果只有--file_name就只顯示層,如果還有--tensor_name就能顯示那一層的權重
2. 只訓練graph中部分變量(相當於凍結了其他變量)
Tensorflow在構建graph的過程中會默認自動收集一些變量名到對應的Collection。例如TRAINABLE_VARIABLES就是所有可訓練的變量集合。
因此可以通過使用tf.get_collection,指定TRAINABLE_VARIABLES,使其僅包含我們需要重新訓練的變量,來凍結其他變量的訓練。
例子如下:
first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
"unit_last")
trainable_variables = first_train_vars
#print trainable_variables
grads = self.optimizer.compute_gradients(self.cost, self.trainable_variables)
3. 更改graph後恢復訓練
根據monitored_session.py,使用MonitoredTrainingSession來開啓控制Session的時候,若指定的checkpoint路徑中有上次的存檔,則現有源碼只能嚴格按照之前訓練恢復。因此我們需要一個空的checkpoint路徑,此時MonitoredTrainingSession就會執行init_op以及init_fn。在init_fn中自己添加恢複函數,並把init_fn作爲參數加入MonitoredTrainingSession中的scaffold即可。
例子如下:
variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=['logit'])
init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
ckpt.model_checkpoint_path, variables_to_restore)
def InitAssignFn(scaffold, sess):
sess.run(init_assign_op, init_feed_dict)
scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)