Caffe學習筆記系列4—基於改進的Saimese網絡的模型訓練和特徵提取
本小節主要講解改進的Siamese網絡的模型訓練,其實也算不上改進,主要講的還是如何準備改進的Siamese網絡的數據準備,對於特徵提取部分則不再介紹,跟前面的方法雷同。所謂的改進是在傳統的Siamese網絡的對比損失函數中添加softmax損失,及損失函數是對比損失和softmax損失的權重和。對於改進的Siamese網絡模型文件,本小節主要講將Caffe裏面的Siamese網絡進行改進,添加softmax損失函數,本部分也是針對該模型操作。首先感性認識一下添加了softmax損失函數的改進的Siamese網絡結構,如下:
圖1 改進的Siamese結構
此外,本人也設計了一個改進的Siamese網絡文件,在“Caffe學習筆記系列”文件夾—>“CaffeTest4”文件夾—>“MySiameseNet”文件夾中,主要借鑑了ALexNet網絡的參數設置,並進行了一些調整,該文件夾加密了的,如果需要的話請私。
但是該節還是針對圖1進行講解。這部分也是比較難的,一方面需要更改Caffe源碼,一方面要準備輸入數據格式。下面着重講解這兩部分。
一、修改源碼
爲了實現softmax和contrastive的結合,需要嘗試修改caffe源碼,包括accuracy_layer.cpp,contrastive_loss_layer.cpp,softmax_loss_layer.cpp等。
即修改三個涉及到label的cpp:
src/caffe/layers/accuracy_layer.cpp,
src/caffe/layers/contrastive_loss_layer.cpp,
src/caffe/layers/softmax_loss_layer.cpp,
修改的細節主要是對應的label處,具體見Caffe工程代碼。
二、創建訓練用的leveldb
寫些代碼生成訓練和測試leveldb用的兩個txt文件,如trainData.txt,valData.txt。生成後的txt內容如下:
Data/0/41_20160503071203/0_1.jpg 0 Data/0/81_20160503071250/0_15.jpg 0
Data/77/41_20160505072228/0_37.jpg 77 Data/56/81_20160505083311/0_23.jpg 56
… …
即每一行都有兩個圖像的名稱和它的label(取值1~93),比如在第一行中,第一個圖像名是Data/0/41_20160503071203/0_1.jpg,它的label是0,第二個圖像名是Data/77/41_20160505072228/0_37.jpg,它的label是77。
這裏Data/0/41_20160503071203/0_1.jpg和Data/0/81_20160503071250/0_15.jpg是相同label的contrastivepair,而Data/77/41_20160505072228/0_37.jpg
和Data/56/81_20160505083311/0_23.jpg是不同label的contrastivepair。如何生成該形式的.txt文件可以參考“CaffeTest4”文件夾下面的“getTrainTxt”工程。
接着,利用trainData.txt和valData.txt生成leveldb,通過“CaffeTest4”文件夾下面的“getSiameseNetFormatLeveled”工程,它讀取txt裏的信息並生成leveldb,這個工程有個主要的函數是ReadImageToDatum_double(),它是在include/caffe/util/io.hpp定義的(注:仿照已有的函數ReadImageToDatum()自定義的),實現是在src/caffe/util/io.cpp中。
由於使用shuffle方法,所以也修改了include/caffe/util/rng.hpp,即在rng.hpp中添加了template<class RandomAccessIterator, class RandomGenerator> inline voidshuffle_double()函數。
小結,以上爲了生成訓練用的leveldb,我們修改的caffe源碼文件有:
include/caffe/util/io.hpp
src/caffe/util/io.cpp
include/caffe/util/rng.hpp
特別的,我們對於ReadImageToDatum_double()生成leveldb時label的處理如下:
label = 第一張圖像的label*100+ 第二張圖像的label;
這樣的處理是爲了:
Softmax層,可以提取出第一張圖像的label;
contrastive層,可以提取出第一張和第二張圖像的label以判斷它們的label是不是相同
注:關於第一張圖像乘以的100,如果庫裏的人數大於100時,應該考慮修改此處label的計算公式。
生成leveldb 格式的數據可以參考“CaffeTest4”文件夾下面的“getSiameseNetFormatLeveled”工程。
三、網絡模型訓練
具體訓練部分不再講解,模型文件都在“CaffeTest4”。
該系列的代碼鏈接如下:https://pan.baidu.com/s/1Z314A-FJ57wXsaJqbdaAew 密碼:g3kv