1、首先打開cifar_train.py 找到最後
if __name__ == '__main__':
tf.app.run()
這個代碼是讓所有的參數生效類似
tf.app.flags.DEFINE_string()
2、開始執行main()函數
def main(argv = none):
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(Flags.train_dir)
train()
以上代碼好理解,就是首先嚐試下載cifar10的數據文件,如果培訓目錄存在,那麼刪除,創建一個新的培訓dir,可以擴展的學習點是:關於tf.gfile的學習 https://blog.csdn.net/pursuit_zhangyu/article/details/80557958
3、研究train()函數
def train():
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
with tf.device("/cpu:0"):
images,labels = cifar10.distorted_imputs()
logits = cifar10.inference(images)
loss = cifar10.loss(logits,labels)
train_op = cifar10.train(loss,global_step)
#此處省略記錄信息的類
with tf.train.MonitoredTrainingSession(
checkpoint _dir = FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps]=),tf.train.NanTensorHook(loss),_LoggerHook()],
config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
) as mon_sess: