1.vggNet簡介
vgg16是2014年由牛津大學提出的一個深度神經網絡模型,該模型在2014年的ILSVRC分類比賽中,取得了第二名的成績,而第一名當屬大名鼎鼎的googleNet,vggNet包含5種網絡類型,如下圖所示:
常見的有vgg16和vgg19。顧名思義vgg16有16層帶權重的層,包含13層卷積層和3層全連接層(池化層沒有參數,不計入層數)。而vgg19包含16層卷積層和3層全連接層。vggNet全部使用1x1,3x3的卷積核,而且vggNet證明了兩個3x3的卷積核可以等效爲一個5x5的卷積核,如下圖所示:
一張5x5的圖經兩個3x3的卷積核卷積後得到一張1x1的特徵圖,等效爲一個5x5的卷積核。同時在參數量上可以發現,5x5的卷積核的參數量是5x5=25,兩個3x3的卷積核是2x3x3=18,參數量減少了28%。同時由於與一個5x5的卷積核卷積只能進行一次非線性激活,而與兩個3x3的卷積核卷積可以進行兩次非線性激活變換,非線性表徵加強了,增加了CNN對特徵的學習能力。另外1x1卷積核能實現降維,並增加非線性。
2.vgg16實現遷移學習
1.數據集準備,我使用8類數據,分別是truck,tiger,flower,kittycat,guitar,houses,plane,person,數據每類訓練集500張,驗證集300張
2.vgg16預訓練權重下載,我把它放在我的百度網盤裏了,密碼fwi4
3.生成train.txt,val.txt,label.txt
create_labels_files.py
# -*-coding:utf-8-*-
import os
import os.path
def write_txt(content, filename, mode='w'):
    """Write rows of values to a text file, one row per line.

    Every item of a row is converted with ``str`` and the items are joined
    with a single space; each non-empty row is terminated by a newline.
    An empty row writes nothing (no blank line is emitted).

    :param content: rows to save, e.g. a list of lists
    :param filename: path of the file to write
    :param mode: file open mode, 'w' (overwrite) or 'a' (append)
    :return: None
    """
    with open(filename, mode) as f:
        for row in content:
            if len(row) == 0:
                # Nothing to write for an empty row.
                continue
            f.write(" ".join(str(item) for item in row) + "\n")
# Class folder name -> integer label written to the train/val listing files.
_CLASS_LABELS = {
    'flower': 0,
    'guitar': 1,
    'person': 2,
    'houses': 3,
    'plane': 4,
    'tiger': 5,
    'kittycat': 6,
    'truck': 7,
}


def get_files_list(dir):
    """Collect every file under ``dir`` (recursively) together with its label.

    The label is looked up from the name of the folder that directly contains
    the file.  Files located in a folder that is not one of the known classes
    are reported and skipped, so they can neither raise an unbound-variable
    error nor silently inherit the label of a previously visited folder.

    :param dir: root directory to walk
    :return: list of ``[<class folder>/<file name>, label]`` entries
    """
    files_list = []
    # parent: current directory, dirnames: its sub-directories,
    # filenames: the files directly inside it.
    for parent, dirnames, filenames in os.walk(dir):
        curr_file = os.path.basename(parent)
        for filename in filenames:
            print("parent is: " + parent)
            print("filename is: " + filename)
            if curr_file not in _CLASS_LABELS:
                print("Warning: skip file in unknown class folder: "
                      + os.path.join(parent, filename))
                continue
            files_list.append([os.path.join(curr_file, filename),
                               _CLASS_LABELS[curr_file]])
    print(files_list)
    return files_list
if __name__ == '__main__':
    # Generate one "<relative image path> <label>" listing per dataset split.
    for split_dir, split_txt in (('dataset/train', 'dataset/train.txt'),
                                 ('dataset/val', 'dataset/val.txt')):
        split_data = get_files_list(split_dir)
        write_txt(split_data, split_txt, mode='w')
4.製作tf.record文件
create_tf_record.py
# -*-coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import random
from PIL import Image
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` holding an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` holding a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def float_list_feature(value):
    """Wrap a sequence of floats in a ``tf.train.Feature`` holding a FloatList.

    Unlike the two helpers above, ``value`` is passed through as the list
    itself rather than being wrapped in a one-element list.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
def get_example_nums(tf_records_filenames):
    """Count the serialized examples stored in a TFRecord file.

    :param tf_records_filenames: path of the TFRecord file to scan
    :return: number of records read from the file
    """
    return sum(1 for _ in tf.python_io.tf_record_iterator(tf_records_filenames))
def show_image(title,image):
    """Display an image with matplotlib and block until the window is closed.

    :param title: text shown as the figure title
    :param image: image data accepted by ``plt.imshow``
    :return: None
    """
    plt.imshow(image)
    plt.title(title)
    plt.axis('on')  # use 'off' here to hide the axes
    plt.show()
def load_labels_file(filename,labels_num=1,shuffle=False):
    """Load an image listing file into parallel lists of paths and labels.

    Each non-blank line describes one image as whitespace-separated fields:
    ``<image path> <label 1> <label 2> ...``, e.g. ``test_image/1.jpg 0 2``.
    Blank lines are ignored and any run of whitespace separates fields, so a
    trailing empty line or doubled spaces no longer abort the load.

    :param filename: path of the listing file
    :param labels_num: number of labels to read per line (taken from the
        fields that directly follow the image path)
    :param shuffle: shuffle the lines (with the ``random`` module) before
        parsing
    :return: ``(images, labels)`` where ``images`` is a list of paths and
        ``labels`` is a list of ``labels_num``-long lists of ints
    :raises IndexError: a line has fewer than ``labels_num`` label fields
    :raises ValueError: a label field is not an integer
    """
    images = []
    labels = []
    with open(filename) as f:
        lines_list = f.readlines()
    if shuffle:
        random.shuffle(lines_list)
    for raw_line in lines_list:
        fields = raw_line.split()
        if not fields:
            # Blank line (commonly the last line of the file).
            continue
        # Index explicitly so a line with too few labels still fails loudly.
        label = [int(fields[i + 1]) for i in range(labels_num)]
        images.append(fields[0])
        labels.append(label)
    return images, labels
def read_image(filename, resize_height, resize_width,normalization=False):
    """Read an image file with OpenCV and return it as an RGB array.

    :param filename: path of the image file
    :param resize_height: target height; resizing happens only when both
        ``resize_height`` and ``resize_width`` are greater than 0
    :param resize_width: target width (see ``resize_height``)
    :param normalization: when true, divide the pixel values by 255.0 so the
        result is floating point instead of the array returned by OpenCV
    :return: the image as a numpy array in RGB channel order
    :raises IOError: the file could not be read/decoded by ``cv2.imread``
    """
    bgr_image = cv2.imread(filename)
    if bgr_image is None:
        # cv2.imread signals failure by returning None; fail here with the
        # file name instead of crashing later inside cvtColor.
        raise IOError('Failed to read image: {}'.format(filename))
    if len(bgr_image.shape) == 2:
        # Single-channel (gray) image: expand to three channels.
        print("Warning:gray image",filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
    print(filename)
    # OpenCV loads images as BGR; convert to RGB.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    if resize_height > 0 and resize_width > 0:
        # cv2.resize takes the size as (width, height).
        rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # Divide by the float 255.0 (not the int 255) so the result is
        # always a true division.
        rgb_image = rgb_image / 255.0
    return rgb_image
def get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=True,num_threads=1):
    """Group single image/label tensors into batches using TF input queues.

    :param images: image tensor for one example
    :param labels: label tensor for one example
    :param batch_size: number of examples per batch
    :param labels_nums: number of classes (depth used for one-hot encoding)
    :param one_hot: convert the label batch to one-hot form when true
    :param shuffle: use ``tf.train.shuffle_batch`` when true, otherwise
        ``tf.train.batch``; typically True for training, False for validation
    :param num_threads: number of threads enqueuing examples
    :return: ``(images_batch, labels_batch)``
    """
    min_after_dequeue = 200
    # capacity has to be larger than min_after_dequeue.
    capacity = min_after_dequeue + 3 * batch_size
    queue_args = dict(batch_size=batch_size,
                      capacity=capacity,
                      num_threads=num_threads)
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch(
            [images, labels],
            min_after_dequeue=min_after_dequeue,
            **queue_args)
    else:
        images_batch, labels_batch = tf.train.batch([images, labels], **queue_args)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch, labels_batch
def read_records(filename,resize_height, resize_width,type=None):
    """Build the graph ops that read one (image, label) pair from a TFRecord.

    The record stores the raw image bytes, which are decoded as uint8 and
    reshaped to ``[resize_height, resize_width, 3]``; this shape must match
    the shape the image had when the record was written.

    :param filename: path of the TFRecord file
    :param resize_height: height of the stored images
    :param resize_width: width of the stored images
    :param type: how the decoded uint8 image is converted:
        None            -> cast to float32 (values unchanged)
        'normalization' -> float32 scaled by 1/255
        'centralization'-> float32 scaled by 1/255, then 0.5 subtracted
        any other value leaves the image as the reshaped uint8 tensor
    :return: ``(image_tensor, label_tensor)``; the label is cast to int32
    """
    # File-name queue with no limit on the number of epochs.
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    # Read one serialized example from the queue.
    _, serialized_example = reader.read(filename_queue)
    feature_spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64)
    }
    features = tf.parse_single_example(serialized_example, features=feature_spec)

    raw_pixels = tf.decode_raw(features['image_raw'], tf.uint8)
    tf_label = tf.cast(features['label'], tf.int32)
    # Restore the image layout; only after this could the image be resized.
    tf_image = tf.reshape(raw_pixels, [resize_height, resize_width, 3])

    # Images are stored as uint8; convert according to the requested type.
    scale = 1. / 255
    if type is None:
        tf_image = tf.cast(tf_image, tf.float32)
    elif type == 'normalization':
        tf_image = tf.cast(tf_image, tf.float32) * scale
    elif type == 'centralization':
        # Scale, then subtract an assumed mean of 0.5.
        tf_image = tf.cast(tf_image, tf.float32) * scale - 0.5

    # Only the image and the label are returned (height/width/depth are not).
    return tf_image, tf_label
def create_records(image_dir,file, output_record_dir, resize_height, resize_width,shuffle,log=5):
    """Serialize images and their labels into a TFRecord file.

    Every example stores the raw image bytes plus height, width, depth and a
    single label.  The image is written exactly as returned by ``read_image``
    (normalization off), so readers must decode it with the matching dtype
    and shape.

    :param image_dir: directory the image paths in ``file`` are relative to
    :param file: listing file with one ``<image path> <label>`` per line
    :param output_record_dir: path of the TFRecord file to write (despite the
        name this is a file path); its parent directory is created if missing
    :param resize_height: stored image height
    :param resize_width: stored image width; when either size is 0 the images
        are not resized
    :param shuffle: shuffle the listing before writing
    :param log: print progress every ``log`` images
    """
    # Load the listing; only the first label of each line is used.
    images_list, labels_list = load_labels_file(file, 1, shuffle)

    output_parent = os.path.dirname(output_record_dir)
    if output_parent and not os.path.isdir(output_parent):
        os.makedirs(output_parent)

    last_index = len(images_list) - 1
    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter(output_record_dir) as writer:
        for i, (image_name, labels) in enumerate(zip(images_list, labels_list)):
            image_path = os.path.join(image_dir, image_name)
            if not os.path.exists(image_path):
                print('Err:no image',image_path)
                continue
            image = read_image(image_path, resize_height, resize_width)
            # tobytes() replaces the deprecated ndarray.tostring().
            image_raw = image.tobytes()
            if i % log == 0 or i == last_index:
                print('------------processing:%d-th------------' % (i))
                print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))
            # Only one label is stored; add more '_int64_feature' entries
            # for multi-label data.
            label = labels[0]
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'height': _int64_feature(image.shape[0]),
                'width': _int64_feature(image.shape[1]),
                'depth': _int64_feature(image.shape[2]),
                'label': _int64_feature(label)
            }))
            writer.write(example.SerializeToString())
def disp_records(record_file,resize_height, resize_width,show_nums=4):
    """Read a TFRecord file and display its first images as a sanity check.

    :param record_file: path of the TFRecord file
    :param resize_height: height of the stored images
    :param resize_width: width of the stored images
    :param show_nums: number of images to display
    :return: None
    """
    tf_image, tf_label = read_records(record_file, resize_height, resize_width, type='normalization')
    # Same initializer batch_test uses; initialize_all_variables is deprecated.
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(show_nums):
                # Fetch one image and its label from the input queue.
                image, label = sess.run([tf_image, tf_label])
                print('shape:{},tpye:{},labels:{}'.format(image.shape,image.dtype,label))
                show_image("image:%d"%(label),image)
        finally:
            # Always stop the queue-runner threads, even if reading failed.
            coord.request_stop()
            coord.join(threads)
def batch_test(record_file,resize_height, resize_width):
    """Read a TFRecord file in batches and show the first image of each batch.

    Fetches four batches of four examples (unshuffled); the batched tensors
    are what would normally be fed to a network.

    :param record_file: path of the TFRecord file
    :param resize_height: height of the stored images
    :param resize_width: width of the stored images
    :return: None
    """
    tf_image, tf_label = read_records(record_file, resize_height, resize_width, type='normalization')
    # labels_nums only matters for one-hot encoding, which is disabled here.
    image_batch, label_batch = get_batch_images(tf_image, tf_label, batch_size=4, labels_nums=5, one_hot=False, shuffle=False)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(4):
                # Fetch one batch of images and labels.
                images, labels = sess.run([image_batch, label_batch])
                # Only the first image of each batch is displayed.
                show_image("image", images[0, :, :, :])
                print('shape:{},tpye:{},labels:{}'.format(images.shape,images.dtype,labels))
        finally:
            # Always stop the queue-runner threads, even if reading failed.
            coord.request_stop()
            coord.join(threads)
if __name__ == '__main__':
    # Settings shared by both splits.
    resize_height = 224  # stored image height
    resize_width = 224  # stored image width
    shuffle = True
    log = 5

    train_record_output = 'dataset/record/train{}.tfrecords'.format(resize_height)
    val_record_output = 'dataset/record/val{}.tfrecords'.format(resize_height)

    # (image directory, listing file, output record, summary message)
    jobs = (
        ('dataset/train', 'dataset/train.txt', train_record_output,
         "save train example nums={}"),
        ('dataset/val', 'dataset/val.txt', val_record_output,
         "save val example nums={}"),
    )
    for image_dir, labels_txt, record_output, message in jobs:
        create_records(image_dir, labels_txt, record_output, resize_height, resize_width, shuffle, log)
        print(message.format(get_example_nums(record_output)))

    # Sanity check: read the train record back in batches and display it.
    # disp_records(train_record_output, resize_height, resize_width) can be
    # used instead to show single images.
    batch_test(train_record_output, resize_height, resize_width)
5.訓練模型
vgg16.py
#vgg16_train_and_val
import tensorflow as tf
import numpy as np
import pdb
import os
from datetime import datetime
from create_tf_record import *
import tensorflow.contrib.slim as slim
# Module-level configuration and graph inputs shared by net_evaluation(),
# Vgg16 and train() below.
print("Tensorflow version:{}".format(tf.__version__))
labels_nums = 8 # number of classes
batch_size = 1 # examples per batch
resize_height = 224 # height of the stored images
resize_width = 224 # width of the stored images
depths = 3
data_shape = [batch_size, resize_height, resize_width, depths]
# input_images: placeholder for a batch of images
input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, depths], name='input')
# input_labels: placeholder for the labels, one row of labels_nums values per
# example (train() feeds one-hot labels). Sparse-label alternative:
# input_labels = tf.placeholder(dtype=tf.int32, shape=[None], name='label')
input_labels = tf.placeholder(dtype=tf.int32, shape=[None, labels_nums], name='label')
# Dropout keep probability and training flag; both are fed by train() and
# net_evaluation().
keep_prob = tf.placeholder(tf.float32,name='keep_prob')
is_training = tf.placeholder(tf.bool, name='is_training')
def net_evaluation(sess,loss,accuracy,val_images_batch,val_labels_batch,val_nums):
    """Average loss and accuracy over the validation batches.

    Runs ``int(val_nums / batch_size)`` steps (``batch_size`` is the
    module-level value); each step pulls one validation batch and evaluates
    ``loss`` and ``accuracy`` on it with dropout disabled.

    :param sess: session used to run the graph
    :param loss: loss tensor
    :param accuracy: accuracy tensor
    :param val_images_batch: tensor yielding a batch of validation images
    :param val_labels_batch: tensor yielding the matching labels
    :param val_nums: number of validation examples
    :return: ``(mean_loss, mean_acc)`` as float32 means over all steps
    """
    steps = int(val_nums / batch_size)
    step_losses = []
    step_accs = []
    for _ in range(steps):
        batch_x, batch_y = sess.run([val_images_batch, val_labels_batch])
        feed = {input_images: batch_x,
                input_labels: batch_y,
                keep_prob: 1.0,
                is_training: False}
        step_loss, step_acc = sess.run([loss, accuracy], feed_dict=feed)
        step_losses.append(step_loss)
        step_accs.append(step_acc)
    mean_loss = np.mean(np.array(step_losses, dtype=np.float32))
    mean_acc = np.mean(np.array(step_accs, dtype=np.float32))
    return mean_loss, mean_acc
class Vgg16:
    """VGG16 convolutional base with fixed pre-trained weights and new FC head.

    The convolution filters and biases are taken from a ``.npy`` weight file
    and used as constants, so only the two dense layers added at the end
    (``fc6`` and ``out``) are trainable.
    """

    # Per-channel means subtracted from the blue, green and red channels
    # (in that order) in __init__.
    vgg_mean = [103.939, 116.779, 123.68]

    def __init__(self, vgg16_npy_path=None,input=None, restore_from=None):
        """Build the network on top of ``input``.

        :param vgg16_npy_path: path of the ``.npy`` file with the pre-trained
            weights (a pickled dict: layer name -> [filters, biases])
        :param input: RGB image batch tensor; it is multiplied by 255 before
            mean subtraction, so values scaled to [0, 1] are expected
        :param restore_from: unused
        :raises FileNotFoundError: the weight file does not exist
        """
        # pre-trained parameters
        try:
            # The file stores a pickled dict, which requires
            # allow_pickle=True on current NumPy versions.
            # NOTE(security): unpickling executes code from the file; only
            # load weight files obtained from a trusted source.
            self.data_dict = np.load(vgg16_npy_path, encoding='latin1',
                                     allow_pickle=True).item()
        except FileNotFoundError:
            print('Please download VGG16 parameters from here https://mega.nz/#!YU1FWJrA!O1ywiCS2IiOlUCtCpI6HTJOMrneN-Qdv3ywQP5poecM\nOr from my Baidu Cloud: https://pan.baidu.com/s/1Spps1Wy0bvrQHH2IMkRfpg')
            # Without the weights no layer can be built, so stop here rather
            # than fail later with a missing data_dict attribute.
            raise
        self.sess = tf.Session()
        self.tfx = input
        self.tfy = tf.placeholder(tf.float32, [None, 1])

        # Convert RGB to BGR and subtract the per-channel means.
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=self.tfx * 255.0)
        bgr = tf.concat(axis=3, values=[
            blue - self.vgg_mean[0],
            green - self.vgg_mean[1],
            red - self.vgg_mean[2],
        ])

        # pre-trained VGG layers are fixed in fine-tune
        conv1_1 = self.conv_layer(bgr, "conv1_1")
        conv1_2 = self.conv_layer(conv1_1, "conv1_2")
        pool1 = self.max_pool(conv1_2, 'pool1')

        conv2_1 = self.conv_layer(pool1, "conv2_1")
        conv2_2 = self.conv_layer(conv2_1, "conv2_2")
        pool2 = self.max_pool(conv2_2, 'pool2')

        conv3_1 = self.conv_layer(pool2, "conv3_1")
        conv3_2 = self.conv_layer(conv3_1, "conv3_2")
        conv3_3 = self.conv_layer(conv3_2, "conv3_3")
        pool3 = self.max_pool(conv3_3, 'pool3')

        conv4_1 = self.conv_layer(pool3, "conv4_1")
        conv4_2 = self.conv_layer(conv4_1, "conv4_2")
        conv4_3 = self.conv_layer(conv4_2, "conv4_3")
        pool4 = self.max_pool(conv4_3, 'pool4')

        conv5_1 = self.conv_layer(pool4, "conv5_1")
        conv5_2 = self.conv_layer(conv5_1, "conv5_2")
        conv5_3 = self.conv_layer(conv5_2, "conv5_3")
        pool5 = self.max_pool(conv5_3, 'pool5')

        # detach original VGG fc layers and
        # reconstruct your own fc layers serve for your own purpose
        pool5_shape = pool5.get_shape().as_list()
        nodes = pool5_shape[1] * pool5_shape[2] * pool5_shape[3]
        self.flatten = tf.reshape(pool5, [-1, nodes])
        self.fc6 = tf.layers.dense(self.flatten, 256, tf.nn.relu, name='fc6')
        # Output logits: one unit per class (module-level labels_nums).
        self.out = tf.layers.dense(self.fc6, labels_nums, name='out')

    def max_pool(self, bottom, name):
        """Apply 2x2 max pooling with stride 2 and SAME padding."""
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        """Apply the pre-trained convolution ``name`` followed by ReLU.

        ``self.data_dict[name]`` holds the filters at index 0 and the biases
        at index 1.
        """
        with tf.variable_scope(name):  # CNN's filter is constant, NOT Variable that can be trained
            conv = tf.nn.conv2d(bottom, self.data_dict[name][0], [1, 1, 1, 1], padding='SAME')
            lout = tf.nn.relu(tf.nn.bias_add(conv, self.data_dict[name][1]))
            return lout

    def train(self, x, y):
        """Run one training step on ``self.sess`` and return the loss.

        NOTE(review): ``self.loss`` and ``self.train_op`` are not defined
        anywhere in this class, so this method raises AttributeError unless
        the caller attaches them first.
        """
        loss, _ = self.sess.run([self.loss, self.train_op], {self.tfx: x, self.tfy: y})
        return loss

    def save(self, path='./model/'):
        """Save the variables of ``self.sess`` to ``path`` (no meta graph)."""
        saver = tf.train.Saver()
        saver.save(self.sess, path, write_meta_graph=False)
def train(train_record_file,
          train_log_step,
          train_param,
          val_record_file,
          val_log_step,
          labels_nums,
          data_shape,
          snapshot,
          snapshot_prefix):
    """Fine-tune the VGG16-based classifier and periodically validate/save it.

    :param train_record_file: TFRecord file with the training data
    :param train_log_step: interval (in steps) between training log lines
    :param train_param: ``[base_lr, max_steps]``
    :param val_record_file: TFRecord file with the validation data
    :param val_log_step: interval (in steps) between validation runs
    :param labels_nums: number of classes
    :param data_shape: ``[batch_size, height, width, depth]`` of the input.
        NOTE(review): net_evaluation() uses the module-level ``batch_size``,
        so the batch size given here should match it.
    :param snapshot: interval (in steps) between checkpoint saves
    :param snapshot_prefix: path prefix of the checkpoint files; its parent
        directory is created if missing
    :return: None
    """
    [base_lr, max_steps] = train_param
    [batch_size, resize_height, resize_width, depths] = data_shape

    # Number of training and validation examples.
    train_nums = get_example_nums(train_record_file)
    val_nums = get_example_nums(val_record_file)
    print('train nums:%d,val nums:%d'%(train_nums,val_nums))

    # Training data: batches are shuffled.
    train_images, train_labels = read_records(train_record_file, resize_height, resize_width, type='normalization')
    train_images_batch, train_labels_batch = get_batch_images(train_images, train_labels,
                                                              batch_size=batch_size, labels_nums=labels_nums,
                                                              one_hot=True, shuffle=True)
    # Validation data: no shuffling needed.
    val_images, val_labels = read_records(val_record_file, resize_height, resize_width, type='normalization')
    val_images_batch, val_labels_batch = get_batch_images(val_images, val_labels,
                                                          batch_size=batch_size, labels_nums=labels_nums,
                                                          one_hot=True, shuffle=False)

    # Define the model.
    vgg = Vgg16(vgg16_npy_path='./vgg16.npy',input=input_images)
    out = vgg.out

    # Losses created through tf.losses are registered automatically, so the
    # total loss below includes this cross-entropy term plus any
    # regularization losses.
    tf.losses.softmax_cross_entropy(onehot_labels=input_labels, logits=out)
    loss = tf.losses.get_total_loss(add_regularization_losses=True)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(input_labels, 1)), tf.float32))

    # Specify the optimization scheme.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=base_lr)
    train_op = slim.learning.create_train_op(total_loss=loss,optimizer=optimizer)

    saver = tf.train.Saver()
    max_acc = 0.0

    # Saving a checkpoint fails if its parent directory does not exist.
    snapshot_dir = os.path.dirname(snapshot_prefix)
    if snapshot_dir and not os.path.isdir(snapshot_dir):
        os.makedirs(snapshot_dir)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(max_steps+1):
                batch_input_images, batch_input_labels = sess.run([train_images_batch, train_labels_batch])
                _, train_loss = sess.run([train_op, loss], feed_dict={input_images:batch_input_images,
                                                                      input_labels:batch_input_labels,
                                                                      keep_prob:0.5, is_training:True})
                # Training accuracy, measured on the current batch only.
                if i%train_log_step == 0:
                    train_acc = sess.run(accuracy, feed_dict={input_images:batch_input_images,
                                                              input_labels: batch_input_labels,
                                                              keep_prob:1.0, is_training: False})
                    print("%s: Step [%d] train Loss : %f, training accuracy : %g" % (datetime.now(), i, train_loss, train_acc))

                # Validation over the whole validation set.
                if i%val_log_step == 0:
                    mean_loss, mean_acc = net_evaluation(sess, loss, accuracy, val_images_batch, val_labels_batch,val_nums)
                    print("%s: Step [%d] val Loss : %f, val accuracy : %g" % (datetime.now(), i, mean_loss, mean_acc))
                    # Keep the model with the best validation accuracy. This
                    # only needs checking when a new accuracy was computed.
                    if mean_acc>max_acc and mean_acc>0.5:
                        max_acc = mean_acc
                        path = os.path.dirname(snapshot_prefix)
                        best_models = os.path.join(path,'best_models_{}_{:.4f}.ckpt'.format(i,max_acc))
                        print('------save:{}'.format(best_models))
                        saver.save(sess, best_models)

                # Regular checkpoint: every `snapshot` steps and at the end.
                if (i %snapshot == 0 and i >0)or i == max_steps:
                    print('-----save:{}-{}'.format(snapshot_prefix,i))
                    saver.save(sess, snapshot_prefix, global_step=i)
        finally:
            # Always stop the queue-runner threads, even if training failed.
            coord.request_stop()
            coord.join(threads)
if __name__ == '__main__':
    base_lr = 0.01  # learning rate
    max_steps = 200000  # number of training iterations
    train_config = dict(
        train_record_file='dataset/record/train224.tfrecords',
        train_log_step=100,
        train_param=[base_lr, max_steps],
        val_record_file='dataset/record/val224.tfrecords',
        val_log_step=200,
        labels_nums=labels_nums,
        data_shape=data_shape,
        snapshot=2000,  # interval between checkpoint saves
        snapshot_prefix='./models/model.ckpt',
    )
    train(**train_config)
3.結果顯示
用實驗室服務器訓練了20萬代,在驗證集上的準確率達到了90.75%。以下是預測結果:
test_images\flower1.jpg
test_images\flower1.jpg is: pre labels:[0],name:['flower'] score: [ 1.]
test_images\flower2.jpg
test_images\flower2.jpg is: pre labels:[0],name:['flower'] score: [ 1.]
test_images\kittycat.jpg
test_images\kittycat.jpg is: pre labels:[6],name:['kittycat'] score: [ 0.4819051]
test_images\kittycat2.jpg
test_images\kittycat2.jpg is: pre labels:[6],name:['kittycat'] score: [ 0.4819051]
test_images\lion.jpg
test_images\lion.jpg is: pre labels:[6],name:['kittycat'] score: [ 0.4819051]
test_images\plane.jpg
test_images\plane.jpg is: pre labels:[4],name:['plane'] score: [ 1.]
test_images\plane2.jpg
test_images\plane2.jpg is: pre labels:[1],name:['guitar'] score: [ 1.]
test_images\tiger0.jpg
test_images\tiger0.jpg is: pre labels:[5],name:['tiger'] score: [ 1.]
test_images\tiger1.jpg
test_images\tiger1.jpg is: pre labels:[5],name:['tiger'] score: [ 1.]
還有改進的空間。