CIFAR-10 DEMO代碼閱讀與理解

1、首先打開cifar_train.py 找到最後

if __name__ == '__main__':
	tf.app.run()

這個代碼是讓所有的參數生效類似

tf.app.flags.DEFINE_string()

2、開始執行main()函數

def main(argv = none):
	cifar10.maybe_download_and_extract()
	if tf.gfile.Exists(FLAGS.train_dir):
		tf.gfile.DeleteRecursively(FLAGS.train_dir)
	tf.gfile.MakeDirs(Flags.train_dir)
	train()

以上代碼好理解,就是首先嚐試下載cifar10的數據文件,如果培訓目錄存在,那麼刪除,創建一個新的培訓dir,可以擴展的學習點是:關於tf.gfile的學習 https://blog.csdn.net/pursuit_zhangyu/article/details/80557958

3、研究train()函數

def train():
	with tf.Graph().as_default():
		global_step = tf.train.get_or_create_global_step()
		with tf.device("/cpu:0"):
			images,labels = cifar10.distorted_imputs()
		logits = cifar10.inference(images)
		loss = cifar10.loss(logits,labels)
		train_op = cifar10.train(loss,global_step)
		#此處省略記錄信息的類
		with tf.train.MonitoredTrainingSession(
			checkpoint _dir = FLAGS.train_dir,
			hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps]=),tf.train.NanTensorHook(loss),_LoggerHook()],
			config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
		) as mon_sess:
		
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章