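# 一個完整的例子：使用數據集(Dataset)完成數據輸入與預處理流程，並進行訓練和測試。
# Complete example: a Dataset-based input and preprocessing pipeline used for both training and testing.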
import tensorflow as tf
# Glob patterns locating the training and test TFRecord shards.
TRAIN_FILE_PATTERN = "path/to/train-file-*"
TEST_FILE_PATTERN = "path/to/test-file-*"

# match_filenames_once returns a variable holding the matching file names;
# it is a local variable, so tf.local_variables_initializer() must run in the
# session before the datasets built from these lists are read.
train_files = tf.train.match_filenames_once(TRAIN_FILE_PATTERN)
test_files = tf.train.match_filenames_once(TEST_FILE_PATTERN)
# Parser that decodes one serialized example read from a TFRecord file.
def parser(record):
    """Decode one TFRecord ``record`` into an ``(image, label)`` pair.

    The record must carry the features listed below: the raw image bytes,
    an integer label, and the image's height / width / channel count.

    Args:
        record: scalar string tensor holding one serialized example.

    Returns:
        ``(decode_image, label)``: ``decode_image`` is a ``uint8`` tensor of
        shape ``[height, width, channels]``; ``label`` is the ``int64`` label.
    """
    features = tf.parse_single_example(
        record,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
        },
    )
    # Raw bytes -> flat (1-D) uint8 pixel vector.
    decode_image = tf.decode_raw(features['image'], tf.uint8)
    # Restore the image to [height, width, channels]. The dimensions are
    # tensors only known at run time, so a dynamic tf.reshape is required:
    # set_shape accepts only static Python ints and cannot change the rank
    # of the 1-D decode_raw output.
    image_shape = tf.stack(
        [features['height'], features['width'], features['channels']])
    decode_image = tf.reshape(decode_image, image_shape)
    label = features['label']
    return decode_image, label
# Input-pipeline hyper-parameters.
image_size = 299        # used as both height and width of the model input
batch_size = 100
shuffle_buffer = 10000  # number of elements buffered by Dataset.shuffle
NUM_EPOCHS = 10         # passes over the training data (Dataset.repeat)


def _train_preprocess(image, label):
    # Training-time preprocessing of the image; the label is passed through.
    # NOTE(review): preprocess_for_train is not defined in this file — it is
    # presumably provided elsewhere with signature (image, height, width, bbox).
    return preprocess_for_train(image, image_size, image_size, None), label


# Training dataset: parse -> preprocess -> shuffle -> batch -> repeat.
dataset = (
    tf.data.TFRecordDataset(train_files)
    .map(parser)
    .map(_train_preprocess)
    .shuffle(shuffle_buffer)
    .batch(batch_size)
    .repeat(NUM_EPOCHS)
)

# Initializable iterator: iterator.initializer must be run in the session
# before image_batch / label_batch can be evaluated.
iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()
# Network structure and optimization step.
# NOTE(review): inference() and calc_loss() are not defined in this file;
# presumably they build the forward pass and the loss — confirm.
learning_rate = 0.01
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# Each sess.run(train_step) applies one gradient-descent update to minimize loss.
train_step = optimizer.minimize(loss)
# Test dataset: parse, resize to the training resolution, batch.
# There is no shuffle and no repeat, so the data is read in a single pass and
# the evaluation loop can stop on OutOfRangeError.
def _resize_for_test(image, label):
    # Resize the decoded image to image_size x image_size; label passes through.
    # NOTE(review): the decoded image is uint8, and resize_images (default
    # bilinear) returns float32 still in the [0, 255] range. If
    # preprocess_for_train rescales to [0, 1], test inputs will be on a
    # different scale than training inputs — confirm.
    return tf.image.resize_images(image, [image_size, image_size]), label


test_dataset = tf.data.TFRecordDataset(test_files)
test_dataset = test_dataset.map(parser).map(_resize_for_test)
test_dataset = test_dataset.batch(batch_size)

# Iterator over the test data; test_iterator.initializer must be run first.
test_iterator = test_dataset.make_initializable_iterator()
test_image_batch, test_label_batch = test_iterator.get_next()

# Prediction = index of the largest logit for each example.
# NOTE(review): inference() is called a second time here. If it creates
# variables via tf.get_variable it needs reuse enabled for this call; if it
# uses tf.Variable, this builds a fresh, untrained copy of the weights.
# Confirm that the test graph shares the trained parameters.
test_logit = inference(test_image_batch)
predictions = tf.argmax(test_logit, axis=-1, output_type=tf.int32)
# Create a session, train the network, then evaluate it on the test set.
with tf.Session() as sess:
    # Initialize the model variables (global) and the local variables; the
    # file lists from match_filenames_once are local variables.
    sess.run((tf.global_variables_initializer(),
              tf.local_variables_initializer()))

    # Initialize the training-data iterator.
    sess.run(iterator.initializer)
    # Train until the input is exhausted (after NUM_EPOCHS passes); the
    # iterator signals the end of the data by raising OutOfRangeError.
    while True:
        try:
            sess.run(train_step)
        except tf.errors.OutOfRangeError:
            break

    # Initialize the test-data iterator and collect predictions and labels.
    sess.run(test_iterator.initializer)
    test_results = []
    test_labels = []
    while True:
        try:
            pred, label = sess.run([predictions, test_label_batch])
        except tf.errors.OutOfRangeError:
            break
        test_results.extend(pred)
        test_labels.extend(label)

    # Accuracy = fraction of test examples whose prediction equals the label.
    correct = [float(y == y_) for (y, y_) in zip(test_results, test_labels)]
    if correct:
        accuracy = sum(correct) / len(correct)
        print("Test accuracy is: ", accuracy)
    else:
        # Avoid ZeroDivisionError when no test examples were read.
        print("No test examples were read; test accuracy is undefined.")