Transfer learning with VGG16: training on your own dataset (with prediction results)

1. Introduction to VGGNet

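VGG16 is a deep neural network proposed in 2014 by the Visual Geometry Group at the University of Oxford. It took second place in the ILSVRC 2014 classification task; first place went to the well-known GoogLeNet. The VGG paper describes several network configurations, labelled A to E.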

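The most common variants are VGG16 and VGG19. As the name suggests, VGG16 has 16 weight layers: 13 convolutional layers and 3 fully connected layers (the max-pooling layers have no weights and are not counted). VGG19 has 16 convolutional layers and 3 fully connected layers. VGGNet uses 3x3 convolution kernels throughout (configuration C additionally uses a few 1x1 kernels), and the paper showed that a stack of two 3x3 convolutions covers the same receptive field as a single 5x5 convolution.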


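A 5x5 input convolved with two 3x3 kernels in succession (stride 1, no padding) yields a 1x1 feature map, exactly as a single 5x5 kernel would. The stack also needs fewer parameters: a 5x5 kernel has 5x5 = 25 weights, while two 3x3 kernels have 2x3x3 = 18, a reduction of (25 - 18) / 25 = 28% per input/output channel pair. Moreover, a single 5x5 convolution is followed by only one nonlinear activation, whereas two stacked convolutions apply two, which strengthens the nonlinear representation and improves the CNN's ability to learn features. Finally, 1x1 kernels can reduce the number of channels (dimensionality reduction) and add extra nonlinearity.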
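To check the arithmetic above numerically, here is a small NumPy sketch (not part of the original project). It uses a naive single-channel "valid" convolution to confirm that two stacked 3x3 kernels reduce a 5x5 input to a single value, compares weight counts for an assumed layer with 64 input and 64 output channels, and shows a 1x1 convolution as a per-pixel matrix product over channels. The helper names `conv2d_valid` and `conv_weights` are illustrative only.

```python
import numpy as np

def conv2d_valid(x, k):
    """Naive single-channel 'valid' cross-correlation: stride 1, no padding."""
    kh, kw = k.shape
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.empty((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out

rng = np.random.default_rng(0)
x = rng.random((5, 5))                      # 5x5 input
k1, k2 = rng.random((3, 3)), rng.random((3, 3))
y = conv2d_valid(conv2d_valid(x, k1), k2)   # 5x5 -> 3x3 -> 1x1
print(y.shape)                              # (1, 1)

def conv_weights(kernel, c_in, c_out):
    """Weight count of one convolution layer (biases ignored)."""
    return kernel * kernel * c_in * c_out

c = 64                                      # assumed channel count
w5 = conv_weights(5, c, c)                  # one 5x5 layer
w3 = 2 * conv_weights(3, c, c)              # two stacked 3x3 layers
print(w5, w3, round(1 - w3 / w5, 2))        # 102400 73728 0.28

# A 1x1 convolution is a matrix product over the channel axis at every pixel,
# so it can shrink 256 channels to 64 while leaving height and width unchanged.
feat = rng.random((7, 7, 256))
w1x1 = rng.random((256, 64))
reduced = np.maximum(feat @ w1x1, 0)        # 1x1 convolution followed by ReLU
print(reduced.shape)                        # (7, 7, 64)
```

The 28% saving does not depend on the channel count, because both options scale with `c_in * c_out`.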


2. Transfer learning with VGG16

1. Dataset preparation. I used 8 classes: truck, tiger, flower, kittycat, guitar, houses, plane and person, with 500 training images and 300 validation images per class.
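Before generating the label files, it can help to confirm the per-class image counts. A minimal sketch, assuming the layout `dataset/train/<class>/` and `dataset/val/<class>/` that the scripts below rely on, and assuming the images use the `.jpg`, `.jpeg` or `.png` extensions; `count_images` is a hypothetical helper, not part of the original code.

```python
import os

IMAGE_EXTS = ('.jpg', '.jpeg', '.png')  # assumed image file extensions

def count_images(split_dir):
    """Return {class_name: number_of_images} for split_dir/<class_name>/<image>."""
    counts = {}
    for name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, name)
        if os.path.isdir(class_dir):  # ignore stray files at the top level
            counts[name] = sum(1 for f in os.listdir(class_dir)
                               if f.lower().endswith(IMAGE_EXTS))
    return counts

for split in ('dataset/train', 'dataset/val'):
    if os.path.isdir(split):  # only runs where the dataset is present
        print(split, count_images(split))
```

For the dataset described above, each of the eight classes should report 500 images under `dataset/train` and 300 under `dataset/val`.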

2. Download the pretrained VGG16 weights. I put them on my Baidu Cloud (password: fwi4).
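Once downloaded, the weight file can be inspected before training. The `Vgg16` class in step 5 loads it with `np.load(...).item()` and reads `data_dict[name][0]` as the kernel and `data_dict[name][1]` as the bias, so the sketch below assumes that dict layout. The path `./vgg16.npy` is where the training script looks for it; `inspect_vgg_npy` is a hypothetical helper.

```python
import os
import numpy as np

def inspect_vgg_npy(path):
    """Print every layer stored in a VGG16 .npy weight file and return the dict.

    Assumes the layout used by the Vgg16 class: a pickled dict
    {layer_name: [weights, biases]} saved with np.save.
    """
    # allow_pickle=True is required on NumPy >= 1.16.3 to load pickled objects
    data_dict = np.load(path, encoding='latin1', allow_pickle=True).item()
    for name in sorted(data_dict):
        weights, biases = data_dict[name][0], data_dict[name][1]
        print(name, weights.shape, biases.shape)
    return data_dict

if os.path.exists('./vgg16.npy'):  # only runs once the weights are downloaded
    inspect_vgg_npy('./vgg16.npy')
```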

3. Generate train.txt, val.txt and label.txt

create_labels_files.py

# -*-coding:utf-8-*-

import os
import os.path

def write_txt(content, filename, mode='w'):
    """Save rows of data to a text file.
    :param content: data to save, type->list of rows (each row is a list)
    :param filename: output file name
    :param mode: write mode: 'w' (overwrite) or 'a' (append)
    :return: None
    """
    with open(filename, mode) as f:
        for line in content:
            # columns are separated by a space; each row ends with "\n"
            f.write(" ".join(str(data) for data in line) + "\n")


def get_files_list(dir):
    '''
    Walk the directory dir recursively and list every file (including files in subfolders).
    The label of each file is taken from the name of the folder that contains it.
    :param dir: root directory to scan
    :return: list of [relative_path, label] pairs -> list
    '''
    # class folder name -> integer label
    class_to_label = {'flower': 0, 'guitar': 1, 'person': 2, 'houses': 3,
                      'plane': 4, 'tiger': 5, 'kittycat': 6, 'truck': 7}
    files_list = []
    # parent: current directory, dirnames: its subfolders, filenames: its files
    for parent, dirnames, filenames in os.walk(dir):
        curr_file = parent.split(os.sep)[-1]
        if curr_file not in class_to_label:
            # skip files that are not inside a known class folder
            continue
        labels = class_to_label[curr_file]
        for filename in filenames:
            # print(os.path.join(parent, filename))  # print the full path of every file
            files_list.append([os.path.join(curr_file, filename), labels])
    return files_list


if __name__ == '__main__':
    train_dir = 'dataset/train'
    train_txt = 'dataset/train.txt'
    train_data = get_files_list(train_dir)
    write_txt(train_data, train_txt, mode='w')

    val_dir = 'dataset/val'
    val_txt = 'dataset/val.txt'
    val_data = get_files_list(val_dir)
    write_txt(val_data, val_txt, mode='w')
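The script above writes only train.txt and val.txt. The prediction output in section 3 maps label indices back to class names, so a file along the lines of label.txt is needed. A minimal sketch, assuming label.txt holds one class name per line with line *i* naming label *i*, in the same order that `get_files_list` assigns; `write_label_file` and `read_label_file` are hypothetical helpers.

```python
import os

# Assumed label order: the integer labels that get_files_list assigns.
CLASS_NAMES = ['flower', 'guitar', 'person', 'houses',
               'plane', 'tiger', 'kittycat', 'truck']

def write_label_file(filename, class_names=CLASS_NAMES):
    """Write one class name per line; line i holds the name of label i."""
    with open(filename, 'w') as f:
        for name in class_names:
            f.write(name + '\n')

def read_label_file(filename):
    """Read the names back, so that names[label] gives the class name."""
    with open(filename) as f:
        return [line.strip() for line in f if line.strip()]

if os.path.isdir('dataset'):  # only writes where the dataset folder exists
    write_label_file('dataset/label.txt')
```

Keeping the order in one list avoids the two files drifting apart if a class is added later.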

4. Create the TFRecord files

create_tf_record.py

# -*-coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import random
from PIL import Image

# int64 feature (used for label, height, width, depth)
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# bytes (string) feature (used for the raw image data)
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# float feature (takes a list of floats)
def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def get_example_nums(tf_records_filenames):
    '''
    Count the number of images (examples) in a tf_records file
    :param tf_records_filenames: path of the tf_records file
    :return: number of examples
    '''
    nums = 0
    for record in tf.python_io.tf_record_iterator(tf_records_filenames):
        nums += 1
    return nums

def show_image(title, image):
    '''
    Display an image
    :param title: image title
    :param image: image data
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')    # use 'off' to hide the axes
    plt.title(title)  # image title
    plt.show()

def load_labels_file(filename,labels_num=1,shuffle=False):
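    '''
    Load a txt list file. Each line describes one image, with fields separated by spaces: image_path label1 label2 ..., e.g.: test_image/1.jpg 0 2
    :param filename: path of the txt file
    :param labels_num: number of labels per line
    :param shuffle: whether to shuffle the lines
    :return: images type->list
    :return: labels type->list
    '''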
    images=[]
    labels=[]
    with open(filename) as f:
        lines_list=f.readlines()
        if shuffle:
            random.shuffle(lines_list)

        for lines in lines_list:
            line=lines.rstrip().split(' ')
            label=[]
            for i in range(labels_num):
                label.append(int(line[i+1]))
            images.append(line[0])
            labels.append(label)
    return images,labels

def read_image(filename, resize_height, resize_width, normalization=False):
    '''
    Read an image; by default the data is returned as uint8 in [0, 255]
    :param filename: image path
    :param resize_height: target height (no resize if <= 0)
    :param resize_width: target width (no resize if <= 0)
    :param normalization: whether to normalize to [0.0, 1.0]
    :return: the RGB image as a numpy array
    '''

    bgr_image = cv2.imread(filename)
    if bgr_image is None:
        # cv2.imread returns None instead of raising when a file cannot be decoded
        raise IOError('Err: cannot read image %s' % filename)
    if len(bgr_image.shape) == 2:  # convert grayscale images to three channels
        print("Warning:gray image", filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
    # show_image(filename,rgb_image)
    # rgb_image=Image.open(filename)
    if resize_height > 0 and resize_width > 0:
        rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # must not be written as rgb_image/255: under Python 2 that is integer division on uint8
        rgb_image = rgb_image / 255.0
    # show_image("src resize image",image)
    return rgb_image


def get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=True,num_threads=1):
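    '''
    :param images: image tensor
    :param labels: label tensor
    :param batch_size: batch size
    :param labels_nums: number of classes
    :param one_hot: whether to convert labels to one-hot form
    :param shuffle: whether to shuffle; usually shuffle=True for training and shuffle=False for validation
    :param num_threads: number of enqueue threads
    :return: batches of images and labels
    '''
    min_after_dequeue = 200
    capacity = min_after_dequeue + 3 * batch_size  # capacity must be larger than min_after_dequeue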
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch([images,labels],
                                                                    batch_size=batch_size,
                                                                    capacity=capacity,
                                                                    min_after_dequeue=min_after_dequeue,
                                                                    num_threads=num_threads)
    else:
        images_batch, labels_batch = tf.train.batch([images,labels],
                                                        batch_size=batch_size,
                                                        capacity=capacity,
                                                        num_threads=num_threads)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch,labels_batch

def read_records(filename,resize_height, resize_width,type=None):
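    '''
    Parse a record file. The stored image data is RGB, uint8, [0, 255]; for training it is usually normalized to [0, 1]
    :param filename: path of the record file
    :param resize_height: height the images were stored with
    :param resize_width: width the images were stored with
    :param type: return type of the image data
         None: convert uint8 [0, 255] to float32 [0, 255] (default)
         normalization: normalize to float32 [0, 1]
         centralization: normalize to float32 [0, 1], then subtract the mean to center the data
    :return: image tensor and label tensor
    '''
    # create a filename queue; the file can be read any number of times
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # the reader reads one serialized example from the filename queue
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # parse the serialized example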
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)  # raw image bytes as a flat uint8 vector

    tf_height = features['height']
    tf_width = features['width']
    tf_depth = features['depth']
    tf_label = tf.cast(features['label'], tf.int32)
    # PS: to restore the image, the reshape size must match the shape it was saved with, otherwise an error occurs
    # tf_image=tf.reshape(tf_image, [-1])    # flatten to a row vector
    tf_image = tf.reshape(tf_image, [resize_height, resize_width, 3])  # set the image dimensions

    # only after the data has been restored can it be resized with resize_images: input uint8 -> output float32
    # tf_image=tf.image.resize_images(tf_image,[224, 224])

    # the images are stored as uint8, but TensorFlow training needs tf.float32
    if type is None:
        tf_image = tf.cast(tf_image, tf.float32)
    elif type == 'normalization':  # [1] if normalization is needed, use:
        # convert_image_dtype only rescales [0, 255] to [0, 1] when the input is uint8
        # tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0)  # normalize
    elif type == 'centralization':
        # normalize and center, assuming a mean of 0.5:
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5  # center

    # only the image and label are returned here
    # return tf_image, tf_height,tf_width,tf_depth,tf_label
    return tf_image, tf_label


def create_records(image_dir,file, output_record_dir, resize_height, resize_width,shuffle,log=5):
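    '''
    Save the raw image data, label, height, width etc. of each image into a record file
    Note: the image data is read as uint8 and stored as a TF bytes feature (BytesList); convert the type as needed when parsing
    :param image_dir: directory of the original images
    :param file: txt file listing the images (image_dir + the path in file gives the image path)
    :param output_record_dir: path of the output record file
    :param resize_height:
    :param resize_width:
    PS: when resize_height or resize_width is 0, no resize is performed
    :param shuffle: whether to shuffle the order
    :param log: interval (in images) between log messages
    '''
    # load the list file, reading only one label per image
    images_list, labels_list = load_labels_file(file, 1, shuffle)

    # TFRecordWriter does not create missing directories
    output_dir = os.path.dirname(output_record_dir)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    writer = tf.python_io.TFRecordWriter(output_record_dir)
    for i, [image_name, labels] in enumerate(zip(images_list, labels_list)):
        image_path = os.path.join(image_dir, image_name)
        if not os.path.exists(image_path):
            print('Err:no image', image_path)
            continue
        image = read_image(image_path, resize_height, resize_width)
        image_raw = image.tobytes()  # same bytes as the older ndarray.tostring()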
        if i%log==0 or i==len(images_list)-1:
            print('------------processing:%d-th------------' % (i))
            print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))
        # only one label is saved here; for multiple labels, add more "'label': _int64_feature(label)"-style entries
        label=labels[0]
        example = tf.train.Example(features=tf.train.Features(feature={
            'image_raw': _bytes_feature(image_raw),
            'height': _int64_feature(image.shape[0]),
            'width': _int64_feature(image.shape[1]),
            'depth': _int64_feature(image.shape[2]),
            'label': _int64_feature(label)
        }))
        writer.write(example.SerializeToString())
    writer.close()

def disp_records(record_file,resize_height, resize_width,show_nums=4):
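    '''
    Parse a record file and display show_nums images, mainly to check that the record file was generated correctly
    :param record_file: path of the record file
    :param resize_height:
    :param resize_width:
    :param show_nums: number of images to display
    :return:
    '''
    # read the record file
    tf_image, tf_label = read_records(record_file, resize_height, resize_width, type='normalization')
    # display the first show_nums images
    init_op = tf.global_variables_initializer()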
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(show_nums):
            image, label = sess.run([tf_image, tf_label])  # fetch image and label in the session
            # image = tf_image.eval()
            # an image parsed from the record without reshape is a flat vector and must be reshaped to display
            # image = image.reshape([height,width,depth])
            print('shape:{},type:{},labels:{}'.format(image.shape, image.dtype, label))
            # pilimg = Image.fromarray(np.asarray(image_eval_reshape))
            # pilimg.show()
            show_image("image:%d"%(label),image)
        coord.request_stop()
        coord.join(threads)


def batch_test(record_file,resize_height, resize_width):
    '''
    :param record_file: path of the record file
    :param resize_height:
    :param resize_width:
    :return:
    :PS: image_batch and label_batch are what is normally fed to the network
    '''
    # read the record file
    tf_image, tf_label = read_records(record_file, resize_height, resize_width, type='normalization')
    image_batch, label_batch = get_batch_images(tf_image, tf_label, batch_size=4, labels_nums=8, one_hot=False, shuffle=False)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:  # start a session
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(4):
            # fetch images and labels in the session
            images, labels = sess.run([image_batch, label_batch])
            # show only the first image of each batch
            show_image("image", images[0, :, :, :])
            print('shape:{},type:{},labels:{}'.format(images.shape, images.dtype, labels))

        # stop all threads
        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    # parameters
    resize_height = 224  # height of the stored images
    resize_width = 224  # width of the stored images
    shuffle = True
    log = 5
    # generate the train record file
    image_dir = 'dataset/train'
    train_labels = 'dataset/train.txt'  # image list file
    train_record_output = 'dataset/record/train{}.tfrecords'.format(resize_height)
    create_records(image_dir, train_labels, train_record_output, resize_height, resize_width, shuffle, log)
    train_nums = get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))

    # generate the val record file
    image_dir = 'dataset/val'
    val_labels = 'dataset/val.txt'  # image list file
    val_record_output = 'dataset/record/val{}.tfrecords'.format(resize_height)
    create_records(image_dir, val_labels, val_record_output, resize_height, resize_width, shuffle, log)
    val_nums = get_example_nums(val_record_output)
    print("save val example nums={}".format(val_nums))

    # test the display functions
    # disp_records(train_record_output,resize_height, resize_width)
    batch_test(train_record_output, resize_height, resize_width)
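The PS comment in `read_records` (the reshape must match the stored shape) follows from how `create_records` stores each image: as raw bytes with no shape attached. A NumPy-only sketch of the same round trip, with a random array standing in for a real image; `np.frombuffer` plays the role that `tf.decode_raw` plays in the TensorFlow code, so this is an analogy rather than the actual TF code path. The last lines spell out the three `type` options of `read_records`.

```python
import numpy as np

h, w = 224, 224
# random uint8 array standing in for an image returned by read_image
image = np.random.randint(0, 256, size=(h, w, 3), dtype=np.uint8)

# create_records stores the array's raw bytes in the 'image_raw' feature
image_raw = image.tobytes()

# decoding gives back a flat uint8 vector (the role tf.decode_raw plays) ...
flat = np.frombuffer(image_raw, dtype=np.uint8)
print(flat.shape)  # (150528,) = 224 * 224 * 3

# ... which becomes the original image again only with the saved shape
restored = flat.reshape(h, w, 3)
print(np.array_equal(restored, image))  # True

try:
    flat.reshape(299, 299, 3)  # wrong shape: the element count does not match
except ValueError as err:
    print('reshape failed:', err)

# the three `type` options of read_records, written in NumPy
as_float = restored.astype(np.float32)        # None: float32 in [0, 255]
normalized = as_float * (1.0 / 255.0)         # 'normalization': [0, 1]
centralized = as_float * (1.0 / 255.0) - 0.5  # 'centralization': about [-0.5, 0.5]
```

This is also why the record files are named after their size (`train224.tfrecords`): reading them back with a different `resize_height`/`resize_width` fails at the reshape.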

5. Train the model

vgg16.py

# vgg16_train_and_val
import tensorflow as tf
import numpy as np
import os
from datetime import datetime
from create_tf_record import *
import tensorflow.contrib.slim as slim

print("Tensorflow version:{}".format(tf.__version__))
labels_nums = 8  # number of classes
batch_size = 1  # batch size
resize_height = 224  # height of the stored images
resize_width = 224  # width of the stored images
depths = 3
data_shape = [batch_size, resize_height, resize_width, depths]

# input_images: image data
input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, depths], name='input')
# input_labels: label data (one-hot)
# input_labels = tf.placeholder(dtype=tf.int32, shape=[None], name='label')
input_labels = tf.placeholder(dtype=tf.int32, shape=[None, labels_nums], name='label')

# dropout keep probability and training flag (fed during training, but not used by the Vgg16 model below)
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
is_training = tf.placeholder(tf.bool, name='is_training')
def net_evaluation(sess,loss,accuracy,val_images_batch,val_labels_batch,val_nums):
    val_max_steps = int(val_nums / batch_size)
    val_losses = []
    val_accs = []
    for _ in range(val_max_steps):
        val_x, val_y = sess.run([val_images_batch, val_labels_batch])
        # print('labels:',val_y)
        # val_loss = sess.run(loss, feed_dict={x: val_x, y: val_y, keep_prob: 1.0})
        # val_acc = sess.run(accuracy,feed_dict={x: val_x, y: val_y, keep_prob: 1.0})
        val_loss,val_acc = sess.run([loss,accuracy], feed_dict={input_images: val_x, input_labels: val_y, keep_prob:1.0, is_training: False})
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    mean_loss = np.array(val_losses, dtype=np.float32).mean()
    mean_acc = np.array(val_accs, dtype=np.float32).mean()
    return mean_loss, mean_acc

class Vgg16:
    vgg_mean = [103.939, 116.779, 123.68]  # ImageNet mean pixel values in B, G, R order

    def __init__(self, vgg16_npy_path=None,input=None, restore_from=None):
        # pre-trained parameters
        try:
            # allow_pickle=True is needed on NumPy >= 1.16.3 to load the pickled dict
            self.data_dict = np.load(vgg16_npy_path, encoding='latin1', allow_pickle=True).item()
        except FileNotFoundError:
            print('Please download VGG16 parameters from here https://mega.nz/#!YU1FWJrA!O1ywiCS2IiOlUCtCpI6HTJOMrneN-Qdv3ywQP5poecM\nOr from my Baidu Cloud: https://pan.baidu.com/s/1Spps1Wy0bvrQHH2IMkRfpg')
            raise  # the network cannot be built without the weights

        # self.tfx = tf.placeholder(tf.float32, [None, 224, 224, 3])
        self.sess = tf.Session()
        self.tfx = input
        self.tfy = tf.placeholder(tf.float32, [None, 1])

        # Convert RGB to BGR
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=self.tfx * 255.0)
        bgr = tf.concat(axis=3, values=[
            blue - self.vgg_mean[0],
            green - self.vgg_mean[1],
            red - self.vgg_mean[2],
        ])

        # pre-trained VGG layers are fixed in fine-tune
        conv1_1 = self.conv_layer(bgr, "conv1_1")
        conv1_2 = self.conv_layer(conv1_1, "conv1_2")
        pool1 = self.max_pool(conv1_2, 'pool1')

        conv2_1 = self.conv_layer(pool1, "conv2_1")
        conv2_2 = self.conv_layer(conv2_1, "conv2_2")
        pool2 = self.max_pool(conv2_2, 'pool2')

        conv3_1 = self.conv_layer(pool2, "conv3_1")
        conv3_2 = self.conv_layer(conv3_1, "conv3_2")
        conv3_3 = self.conv_layer(conv3_2, "conv3_3")
        pool3 = self.max_pool(conv3_3, 'pool3')

        conv4_1 = self.conv_layer(pool3, "conv4_1")
        conv4_2 = self.conv_layer(conv4_1, "conv4_2")
        conv4_3 = self.conv_layer(conv4_2, "conv4_3")
        pool4 = self.max_pool(conv4_3, 'pool4')

        conv5_1 = self.conv_layer(pool4, "conv5_1")
        conv5_2 = self.conv_layer(conv5_1, "conv5_2")
        conv5_3 = self.conv_layer(conv5_2, "conv5_3")
        pool5 = self.max_pool(conv5_3, 'pool5')

        # detach original VGG fc layers and
        # reconstruct your own fc layers serve for your own purpose
        pool5_shape = pool5.get_shape().as_list()
        nodes = pool5_shape[1] * pool5_shape[2] * pool5_shape[3]
        self.flatten = tf.reshape(pool5, [-1, nodes])
        self.fc6 = tf.layers.dense(self.flatten, 256, tf.nn.relu, name='fc6')
        self.out = tf.layers.dense(self.fc6, labels_nums, name='out')


    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):   # CNN's filter is constant, NOT Variable that can be trained
            conv = tf.nn.conv2d(bottom, self.data_dict[name][0], [1, 1, 1, 1], padding='SAME')
            lout = tf.nn.relu(tf.nn.bias_add(conv, self.data_dict[name][1]))
            return lout

    def train(self, x, y):
        loss, _ = self.sess.run([self.loss, self.train_op], {self.tfx: x, self.tfy: y})
        return loss



    def save(self, path='./model/'):
        saver = tf.train.Saver()
        saver.save(self.sess, path, write_meta_graph=False)

def train(train_record_file,
          train_log_step,
          train_param,
          val_record_file,
          val_log_step,
          labels_nums,
          data_shape,
          snapshot,
          snapshot_prefix):
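    '''
    :param train_record_file: tfrecord file for training
    :param train_log_step: interval (in steps) between training log messages
    :param train_param: training parameters [base_lr, max_steps]
    :param val_record_file: tfrecord file for validation
    :param val_log_step: interval (in steps) between validation runs
    :param labels_nums: number of classes
    :param data_shape: shape of the input data
    :param snapshot: interval (in steps) between model snapshots
    :param snapshot_prefix: file name prefix for saved models
    :return:
    '''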
    [base_lr,max_steps]=train_param
    [batch_size,resize_height,resize_width,depths]=data_shape

    # number of training and validation examples
    train_nums = get_example_nums(train_record_file)
    val_nums = get_example_nums(val_record_file)
    print('train nums:%d,val nums:%d' % (train_nums, val_nums))

    # read images and labels from the record files
    # training data: usually shuffled (shuffle=True)
    train_images, train_labels = read_records(train_record_file, resize_height, resize_width, type='normalization')
    train_images_batch, train_labels_batch = get_batch_images(train_images, train_labels,
                                                              batch_size=batch_size, labels_nums=labels_nums,
                                                              one_hot=True, shuffle=True)
    # validation data: no need to shuffle
    val_images, val_labels = read_records(val_record_file, resize_height, resize_width, type='normalization')
    val_images_batch, val_labels_batch = get_batch_images(val_images, val_labels,
                                                          batch_size=batch_size, labels_nums=labels_nums,
                                                          one_hot=True, shuffle=False)

    # Define the model:
    # with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    #     out, end_points = inception_v3.inception_v3(inputs=input_images, num_classes=labels_nums, dropout_keep_prob=keep_prob, is_training=is_training)
    vgg = Vgg16(vgg16_npy_path='./vgg16.npy', input=input_images)
    out = vgg.out
    # Specify the loss function: losses defined via tf.losses are added to the loss collection automatically, so add_loss() is not needed
    tf.losses.softmax_cross_entropy(onehot_labels=input_labels, logits=out)  # add the cross-entropy loss (loss=1.6)
    # slim.losses.add_loss(my_loss)
    loss = tf.losses.get_total_loss(add_regularization_losses=True)  # also add regularization losses (loss=2.2)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(input_labels, 1)), tf.float32))
    # Specify the optimization scheme:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=base_lr)
    train_op = slim.learning.create_train_op(total_loss=loss,optimizer=optimizer)




    saver = tf.train.Saver()
    max_acc=0.0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(max_steps+1):
            batch_input_images, batch_input_labels = sess.run([train_images_batch, train_labels_batch])
            _, train_loss = sess.run([train_op, loss], feed_dict={input_images:batch_input_images,
                                                                      input_labels:batch_input_labels,
                                                                      keep_prob:0.5, is_training:True})
            # training accuracy (measured on the current training batch only)
            if i%train_log_step == 0:
                train_acc = sess.run(accuracy, feed_dict={input_images:batch_input_images,
                                                          input_labels: batch_input_labels,
                                                          keep_prob:1.0, is_training: False})
                print("%s: Step [%d]  train Loss : %f, training accuracy :  %g" % (datetime.now(), i, train_loss, train_acc))

            # validation (on the full validation set)
            if i%val_log_step == 0:
                mean_loss, mean_acc=net_evaluation(sess, loss, accuracy, val_images_batch, val_labels_batch,val_nums)
                print("%s: Step [%d]  val Loss : %f, val accuracy :  %g" % (datetime.now(), i, mean_loss, mean_acc))

            # save the model every `snapshot` steps and at the last step
            if (i %snapshot == 0 and i >0)or i == max_steps:
                print('-----save:{}-{}'.format(snapshot_prefix,i))
                saver.save(sess, snapshot_prefix, global_step=i)
            # save the model with the best validation accuracy so far
            if mean_acc>max_acc and mean_acc>0.5:
                max_acc=mean_acc
                path = os.path.dirname(snapshot_prefix)
                best_models=os.path.join(path,'best_models_{}_{:.4f}.ckpt'.format(i,max_acc))
                print('------save:{}'.format(best_models))
                saver.save(sess, best_models)

        coord.request_stop()
        coord.join(threads)



if __name__ == '__main__':
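    train_record_file = 'dataset/record/train224.tfrecords'
    val_record_file = 'dataset/record/val224.tfrecords'

    train_log_step = 100
    base_lr = 0.01  # learning rate
    max_steps = 200000  # number of training iterations
    train_param = [base_lr, max_steps]

    val_log_step = 200
    snapshot = 2000  # save a snapshot every 2000 steps
    snapshot_prefix = './models/model.ckpt'
    # tf.train.Saver does not create the output directory itself
    os.makedirs(os.path.dirname(snapshot_prefix), exist_ok=True)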
    train(train_record_file=train_record_file,
          train_log_step=train_log_step,
          train_param=train_param,
          val_record_file=val_record_file,
          val_log_step=val_log_step,
          labels_nums=labels_nums,
          data_shape=data_shape,
          snapshot=snapshot,
          snapshot_prefix=snapshot_prefix)
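Two details of the model above are easy to check by hand: the input scaling (the records are read as RGB in [0, 1], and `Vgg16` multiplies by 255, reorders to BGR and subtracts `vgg_mean`), and the length of the flattened `pool5` output that feeds `fc6`. A NumPy sketch of both computations, for illustration only; `vgg_preprocess` and `flatten_size` are hypothetical helpers mirroring the TensorFlow code, not replacements for it.

```python
import numpy as np

VGG_MEAN_BGR = np.array([103.939, 116.779, 123.68])  # same values as Vgg16.vgg_mean

def vgg_preprocess(rgb01):
    """NumPy mirror of the first step in Vgg16.__init__.

    rgb01: array of shape (N, H, W, 3), RGB in [0, 1]
    (what read_records(..., type='normalization') yields).
    Returns BGR on the [0, 255] scale with the per-channel mean subtracted.
    """
    rgb = np.asarray(rgb01, dtype=np.float64) * 255.0
    bgr = rgb[..., ::-1]          # reverse the channel axis: RGB -> BGR
    return bgr - VGG_MEAN_BGR     # broadcasts over N, H, W

def flatten_size(height, width, channels=512, num_pools=5):
    """Length of the flattened pool5 output, assuming 2x2/stride-2 'SAME' pooling.

    The 3x3 'SAME' convolutions keep H and W, so only the pooling layers shrink them.
    """
    for _ in range(num_pools):
        height, width = -(-height // 2), -(-width // 2)  # ceil(x / 2)
    return height * width * channels

print(flatten_size(224, 224))  # 25088 = 7 * 7 * 512, the input size of fc6
```

Because the convolution kernels are loaded as constants, only `fc6` and `out` are trained, which keeps the number of trainable parameters at roughly 25088 x 256 plus 256 x 8.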

3. Results

I trained the model for 200,000 iterations on the lab server, and accuracy on the validation set reached 90.75%. The prediction results are:

test_images\flower1.jpg
test_images\flower1.jpg is: pre labels:[0],name:['flower'] score: [ 1.]
test_images\flower2.jpg
test_images\flower2.jpg is: pre labels:[0],name:['flower'] score: [ 1.]
test_images\kittycat.jpg
test_images\kittycat.jpg is: pre labels:[6],name:['kittycat'] score: [ 0.4819051]
test_images\kittycat2.jpg
test_images\kittycat2.jpg is: pre labels:[6],name:['kittycat'] score: [ 0.4819051]
test_images\lion.jpg
test_images\lion.jpg is: pre labels:[6],name:['kittycat'] score: [ 0.4819051]
test_images\plane.jpg
test_images\plane.jpg is: pre labels:[4],name:['plane'] score: [ 1.]
test_images\plane2.jpg
test_images\plane2.jpg is: pre labels:[1],name:['guitar'] score: [ 1.]
test_images\tiger0.jpg
test_images\tiger0.jpg is: pre labels:[5],name:['tiger'] score: [ 1.]
test_images\tiger1.jpg
test_images\tiger1.jpg is: pre labels:[5],name:['tiger'] score: [ 1.]

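There is still room for improvement: plane2.jpg was classified as guitar, lion.jpg (not one of the eight classes) was labelled kittycat, and the three kittycat/lion predictions all share the same low score of 0.4819051.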
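The prediction script that produced this output is not included in the post. A minimal sketch of how such lines could be derived from the network output, assuming the score is the softmax probability of the top class computed from `vgg.out` and the names follow the label order used in `get_files_list`; the logits below are made-up numbers, not outputs of the trained model.

```python
import numpy as np

# assumed label order, matching get_files_list
CLASS_NAMES = ['flower', 'guitar', 'person', 'houses',
               'plane', 'tiger', 'kittycat', 'truck']

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def decode_predictions(logits, class_names=CLASS_NAMES):
    """Return (label indices, class names, top-1 scores) for a batch of logits."""
    probs = softmax(np.asarray(logits, dtype=np.float64))
    pre_labels = probs.argmax(axis=1)
    scores = probs[np.arange(len(pre_labels)), pre_labels]
    names = [class_names[i] for i in pre_labels]
    return pre_labels, names, scores

# hypothetical logits for one image
logits = np.array([[0.1, 0.0, 0.2, 0.0, 0.3, 1.5, 2.0, 0.0]])
pre_labels, names, scores = decode_predictions(logits)
print('pre labels:{},name:{} score: {}'.format(pre_labels, names, scores))
```

A softmax classifier always assigns one of its eight classes, so an out-of-distribution image such as lion.jpg still receives a label; a low top-1 score is the only hint that the prediction is uncertain.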
