【深度學習-微調模型】使用Tensorflow Slim fine-tune(微調)模型

本文主要講解在現有常用模型基礎上,如何微調模型,減少訓練時間,同時保持模型檢測精度。

首先介紹下Slim這個Google公佈的圖像分類工具包,可在github鏈接:modules and examples built with tensorflow 中找到slim包。

上面這個鏈接目錄下主要包含:

official models(這個是用Tensorflow高層API做的例子模型集,建議初學者可嘗試);

research models(這個是很多研究者利用tensorflow做的模型集,這個不是官方提供的,是研究者個人在維護的);

samples folder (包含代碼片段和小的模型用以表述tensorflow特性,包含以博客形式存在的代碼呈現);

而我說的slim工具包就在research文件夾下。


Slim庫結構

不僅定義了很多接口,還提供了很多ImageNet數據集上常用的網絡結構和預訓練模型(包括Alexnet,CycleGAN,DCGAN,VGG16,VGG19,Inception V1~V4,ResNet 50, ResNet 101,MobileNet V1等)。

 


下面用slim工具包中的文件來對自己的數據集做訓練,訓練可分爲利用已有的模型架構(如常見的VGG,Inception等的卷積,池化這些結構)來全新訓練權重文件或是微調權重文件。由於很多已有的imagenet圖像數據覆蓋面已經很廣,基於此訓練的網絡權重已經能提取大致的目標特徵(從低微像素到高維的結構特徵),所以可使用fine-tune只訓練框架中某些層的權重,當然根據自己數據集做全部權重重新訓練的檢測效果理論會更好些,需要權衡時間成本和檢測精度的需求了;

下面會依據成熟網絡結構Incvption V3分別做權重文件的全部重新訓練部分重新訓練(即fine-tune)來介紹;

(前提是你將slim工具庫下載下來,安裝了必要的tensorflow等框架;並且根據訓練圖像製作完成了tfrecord文件)

有關tfrecord訓練文件的製作請參考:將圖像製作成tfrecord

step1:定義新的datasets數據集文件

在slim/datasets/文件夾下 添加一個python文件,直接複製一份flowers.py,重命名爲“satellite.py”(這個名字可根據你實際的數據集名字來更改,我用的是何大神的航拍圖數據集)

需要對賦值生成後的satellite.py內容做如下修改:

_FILE_PATTERN = 'flowers_%s_*.tfrecord' 

更改爲

_FILE_PATTERN = 'satellite_%s_*.tfrecord'      #這個主要是根據你之前製作的tfrecord文件名來改的,我製作的訓練文件爲satellite_train_00000-of-00002.tfrecord和satellite_train_00001-of-00002.tfrecord,驗證文件爲satellite_validation_00000-of-00002.tfrecord,satellite_validation_00001-of-00002.tfrecord

SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}

更改爲

SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}  #這個根據自己訓練和驗證樣本數量來改,我的訓練數據是800張圖/類,共6類,驗證集時200張/類,共6類;

_NUM_CLASSES = 5

更改爲

_NUM_CLASSES = 6       #實際訓練類別爲6類;

 

還需要對satellite.py文件中的'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),這行代碼做更改,由於用的數據集源文件都是XXXX.jpg格式,因此將默認的圖像格式轉爲jpg,更改後爲'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 至此,對satellite.py文件完成製作與更改(其源碼如下):

satellite.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.

The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'satellite_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}

_NUM_CLASSES = 6

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 4',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading flowers.

  Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)

step2:註冊數據庫

接下來對slim/datasets/dataset_factory.py文件做更改,註冊下satellite數據庫;修改之處如下(添加了兩行紅色字體代碼):

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'satellite': satellite,
    
}

step3:準備訓練文件夾

在slim文件夾下新建如下目錄文件夾,並將對應的文件放在相應目錄下

slim/
    satellite/
              data/
                   satellite_train_00000-of-00002.tfrecord
                   satellite_train_00001-of-00002.tfrecord
                   satellite_validation_00000-of-00002.tfrecord
                   satellite_validation_00001-of-00002.tfrecord
                   label.txt
              pretrained/
                   inception_v3.ckpt
              train_dir/

data文件夾下存放你製作的tfrecord訓練測試文件和標籤名;

pretrained文件夾下存放官網訓練的權重文件;下載地址:http:/!download. tensorflow .org/models/inception _ v3_2016 _ 08 _ 28.tar.gz      

train_dir文件夾下存放你訓練得到的模型和日誌;

step4-1:在現有模型結構上fine-tune

開始訓練,在slim文件夾下,運行如下指令可開始訓練(主要是訓練邏輯層):

python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=2 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

命令參數解析如下:

--trainable_ scopes=Inception V3/Logits,InceptionV3/ AuxLogits :首先來解 釋參數trainable_scopes 的作用,因爲非常重要。 trainable_scopes 規定了在模型中fine-tune變量的範圍 。 這裏的設定表示只對 InceptionV3/Logits, Inception V3/ AuxLogits 兩個變量進行微調,其他變量都保持不動 。 Inception V3/Logits,Inception V3/ AuxLogits 就相當於在網絡中的 fc8 ,它們是 Inception V3的“末端層” 。 如果不設定 trainable_scopes , 就會對模型中所有的參數進行訓練。

• --train_dir=satellite/train_dir:表明會在 satellite/train_dir目錄下保存日誌和checkpoint。

--dataset_name=satellite、 --dataset_split_ name=train: 指定訓練的數據集。

--dataset_dit=satellite/data:指定訓練數據集保存的位置。 

--model_ name=inception _ v3 :使用的模型名稱。 

--checkpoint_path=satellite/pretrained/inception_v3.ckpt:預訓練模型的保存位置。

--checkpoint_exclude_scopes=Inception V3/Logits,InceptionV3/ AuxLogits : 在恢復預訓練模型時,不恢復這兩層。正如之前所說,這兩層是 Inception V3 模型的末端層,對應着 ImageNet 數據集的 1000 類,和相當前的數據集不符,因此不要去恢復它。

--max_number_of_steps 100000:最大的執行步數。

--batch_size=32:每步使用的 batch 數量。

--learning_rate=0.001 : 學習率。

• --learning_rate_decay_type=fixed:學習率是否自動下降,此處使用固定的學習率。

• --save_interval_secs=300:每隔 300s,程序會把當前模型保存到train_dir中。 此處就是目錄 satellite/train_dir。

• --save_summaries_secs=2:每隔 2s,就會將日誌寫入到 train_dir 中。可以用 TensorBoard 查看該日誌。此處爲了方便觀察,設定的時間間隔較多,實際訓練時,爲了性能考慮,可以設定較長的時間間隔。

• --log_every_n_steps=10:每隔10步,就會在屏上打出訓練信息。

--optimizer=msprop:表示選定的優化器。

• --weight_decay=0.00004:選定的 weight_decay 值。 即模型中所高參數的 二次正則化超參數。


以上命令是隻訓練末端層 InceptionV3/Logits,Inception V3/ AuxLogits ,還 可以使用以下命令對所高層進行訓練:

step4-2:訓練整個模型權重數據

使用以下命令對所有層進行訓練:
去掉 了--trainable_scopes 參數

python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=2 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

當train_image_classifier.py程序啓動後,如果訓練文件夾(即satellite/train_dir)裏沒再已經保存的模型,就會加載 checkpoint_path 中的預訓練模型,緊接着,程序會把初始模型保存到 train_dir中 ,命名爲 model.ckpt-0, 0 表示第 0 步。 這之後,每隔 5min (參數一save_interval_secs=300 指定了每隔 300s 保存一次,即 5min )。 程序還會把當前模型保存到同樣的文件夾中 , 命名恪式和第一次保存的格式一樣。 因爲模型比較大,程序只會保留最新的 5 個模型。
此外,如果中斷了程序並再次運行,程序會首先檢查 train_dir 中有無已經保存的模型,如果有,就不會去加載 checkpoint_path 中的預訓練模型, 而是直接加載 train_dir 中已經訓練好的模型,並以此爲起點進行訓練。 Slim 之所以這樣設計,是爲了在微調網絡的時候,可以方便地按階段手動調整學習率等參數。
 

至此用slim工具包做fine-tune或重新訓練的步驟就完成了。


相似文章參考:https://blog.csdn.net/chaipp0607/article/details/74139895

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章