Caffe學習筆記系列3—基於Siamese網絡的模型訓練和特徵提取
本節主要講解Siamese網絡的模型訓練和特徵提取,其中特徵提取將不再講述,和Caffe學習筆記系列2是一樣的。本節主要是講解如何訓練Siamese網絡模型以及如何準備該網絡需要的輸入數據格式。
在“Caffe學習筆記系列”文件夾中建立“CaffeTest3”文件夾,本節的所有操作在該文件夾進行。
本節分爲兩部分,第一部分是採用Siamese網絡訓練mnist數據集,第二部分是採用Siamese網絡訓練自己標註的數據,其中第一部分較爲簡單,第二部分的難點在於如何準備輸入數據的格式。下面先簡要介紹第一部分。
一、採用Siamese網絡訓練mnist數據集
本部分的資料在“Caffe學習筆記系列”文件夾—>“CaffeTest3”文件夾—>“SiameseNetMnist”文件夾中。
該部分主要包括以下幾個步驟:
1、 在“CaffeTest3”文件夾裏面建立“SiameseNetMnist”文件夾,訓練mnist數據集在該文件夾下操作;
2、下載mnist數據集,將其放在“mnist”文件夾裏面,將Caffe中的mnist_siamese.prototxt、mnist_siamese_solver.prototxt、mnist_siamese_train_test.prototxt放到“SiameseNetMnist”文件夾;
3、 建立create_mnist_train_leveldb.bat批處理文件並運行得到訓練數據,裏面編寫如下代碼:
..\..\CaffeDev\caffe-master\Build\x64\Release\convert_mnist_siamese_data.exe ./mnist/train-images-idx3-ubyte ./mnist/train-labels-idx1-ubyte trainleveldb
pause
4、 建立create_mnist_test_leveldb.bat批處理文件並運行得到測試數據,裏面編寫如下代碼:
..\..\CaffeDev\caffe-master\Build\x64\Release\convert_mnist_siamese_data.exe ./mnist/t10k-images-idx3-ubyte ./mnist/t10k-labels-idx1-ubyte testleveldb
pause
5、 建立train.bat批處理文件並運行,開始訓練,裏面編寫如下代碼:
..\..\CaffeDev\caffe-master\Build\x64\Release\caffe.exe train --solver=mnist_siamese_solver.prototxt
pause
二、採用Siamese網絡訓練自己的數據集
本部分的資料在“Caffe學習筆記系列”文件夾—>“CaffeTest3”文件夾—>“SiameseNetMyData”文件夾中。
這一部分主要講解如何利用Siamese網絡訓練自己的數據,該部分比較棘手的地方在於如何得到訓練的數據格式。Siamese網絡的目的是使得相同類別的數據儘可能近,不同類別的數據儘可能遠。輸入的數據形如“Data/0/41_20160503071203/0_0.jpg(圖片1) 0 (圖片1的類別) Data/0/81_20160503071250/0_19.jpg(圖片2) 0(圖片2的類別)”,這時可以在程序中通過判斷圖片1的類別和圖片2的類別來決定兩張輸入圖片是否是同一類別,但是Caffe中不支持這種形式,需要對Caffe源碼進行改動,下一系列對這種改動進行講解;
Caffe中支持的Siamese形式如下:
“Data/0/41_20160503071203/0_0.jpg(圖片1) Data/0/81_20160503071250/0_19.jpg(圖片2) 0/1”,如果圖片1和圖片2類別相同則爲0,否則爲1。下面着重講解第二種格式。
1、 在“CaffeTest3”文件夾裏面建立“SiameseNetMyData”文件夾,下面訓練自己的數據集在該文件夾下操作;
2、 在該文件夾中建立“Data”文件夾,裏面的數據和系列1中的數據一樣;
3、 將Data文件夾中的數據生成形如“Data/0/41_20160503071203/0_1.jpg Data/0/81_20160503071250/0_10.jpg 0”的txt數據格式;在此提供一份C++代碼,見附1;
4、將txt裏面的數據生成leveldb格式,該部分的代碼見附2;
5、建立批處理文件train.bat,並運行之,裏面編寫如下代碼:
..\..\CaffeDev\caffe-master\Build\x64\Release\caffe.exe train --solver=mnist_siamese_solver.prototxt
pause
注意:mnist_siamese_solver.prototxt所引用的網絡定義文件(mnist_siamese_train_test.prototxt)中,Slice層的slice_point需要與單張圖像的通道數一致(RGB圖像爲3,灰度圖像爲1):
slice_dim: 1
#RGB--3
#Gray--1
slice_point: 3
到此基本上把Siamese網絡需要的數據格式準備完畢。當然本節講述的Siamese網絡是Caffe自帶的,本人也對AlexNet網絡進行了調整,利用AlexNet網絡的層參數設置將其組合成Siamese網絡,詳見“CaffeTest3”文件夾下面的“MySiameseNet”文件夾。本人設計的Siamese網絡文件在“Caffe學習筆記系列”文件夾—>“CaffeTest3”文件夾—>“MySiameseNet”文件夾中,該文件夾是加密了的,如果需要請私信。
附1:
#include<vector>
#include<opencv2\opencv.hpp>
#include<iostream>
#include<string>
#include<fstream>
#include<cstdlib>
#include<direct.h>
#include<time.h>
using namespace std;
using namespace cv;
//生成Siamese網絡訓練的txt數據格式
//圖片存放的目錄如下:Data/類別/相機號(41,81)/xx.jpg
//得到的訓練集的txt格式:Data/類別/相機號/xx.jpg Data/類別/相機號/xx.jpg 0/1
//訓練集:驗證集=5:1
//samepairs:different pairs=1:10
void main()
{
int trainToval = 5;
int diffTosame = 10;
string trainData ="../../trainData.txt";
string valData ="../../valData.txt";
ofstream trainOut(trainData);
ofstream valOut(valData);
srand((unsigned)time(NULL));
int numstart = 0, numend = 93;//類別;
for (int i = numstart; i < numend;i++)
{
string mainFolder ="../../Data/";
mainFolder = mainFolder +to_string(i);
Directory dir;
string exten = "*";
bool addPath = true;
vector<string>filenames = dir.GetListFolders(mainFolder, exten, addPath);
for(int j = 0; j < filenames.size(); j++)
{
vector<string>tmp = dir.GetListFiles(filenames[j], "*.jpg", true);
for (int k = 0; k< tmp.size(); k++)
{
//生成正樣本對
vector<string>otherSame;//相同類別在不同相機編號下對應的圖片
for (int j1= 0; j1 < filenames.size(); j1++)
{
if(j1 == j) continue;
vector<string>tmp1 = dir.GetListFiles(filenames[j1], "*.jpg", true);
for(int k1 = 0; k1 < tmp1.size(); k1++)
otherSame.push_back(tmp1[k1]);
}
int posnum= (rand() % (otherSame.size() - 0)) + 0;
//生成負樣本對
vector<string>otherDiff;
for (int i1= 0; i1 < numend; i1++)
{
if(i1 == i) continue;
mainFolder= "../../Data/" + to_string(i1);
vector<string>others = dir.GetListFolders(mainFolder, exten, addPath);
for(int j2 = 0; j2 < others.size(); j2++)
{
vector<string>tmp2 = dir.GetListFiles(others[j2], "*.jpg", true);
for(int k2 = 0; k2 < tmp2.size(); k2++)
otherDiff.push_back(tmp2[k2]);
}
}
if (k % trainToval!= 0)//訓練集:驗證集=5:1
{
trainOut<< tmp[k] << " "<< otherSame[posnum] <<" " << 0 << endl;//正樣本對
for(int n = 0; n < diffTosame; n++)
{
intnegnum = (rand() % (otherDiff.size() - 0)) + 0;
//===============================
intindex1 = 0, index2 = 0;
intcount = 0;
for(int i = otherDiff[negnum].size() - 1; i >= 0; i--)
{
if(otherDiff[negnum][i] == '/')
count++;
if(count == 2)
{
index2= i; break;
}
}
count= 0;
for(int i = otherDiff[negnum].size() - 1; i >= 0; i--)
{
if(otherDiff[negnum][i] == '/')
count++;
if(count == 3)
{
index1= i; break;
}
}
stringlabelsubstr = otherDiff[negnum].substr(index1 + 1, index2 - index1 - 1);
intlabel = atoi(labelsubstr.c_str());
//==============================
trainOut<< tmp[k] << " " << otherDiff[negnum] <<" " << 1 << endl;//負樣本對
}
}
else//驗證集
{
valOut << tmp[k]<< " " << otherSame[posnum] << " "<< 0 << endl;//正樣本對
for(int n = 0; n < diffTosame; n++)
{
intnegnum = (rand() % (otherDiff.size() - 0)) + 0;
//===============================
intindex1 = 0, index2 = 0;
intcount = 0;
for(int i = otherDiff[negnum].size() - 1; i >= 0; i--)
{
if(otherDiff[negnum][i] == '/')
count++;
if(count == 2)
{
index2= i; break;
}
}
count= 0;
for(int i = otherDiff[negnum].size() - 1; i >= 0; i--)
{
if(otherDiff[negnum][i] == '/')
count++;
if(count == 3)
{
index1= i; break;
}
}
stringlabelsubstr = otherDiff[negnum].substr(index1 + 1, index2 - index1 - 1);
intlabel = atoi(labelsubstr.c_str());
//==============================
valOut<< tmp[k] << " " << otherDiff[negnum] <<" " << 1 << endl;//負樣本對
}
}
}
}
}
}
附2:
#include"getSiameseNetInputFormat.h"
// gflags command-line options for the pair-to-leveldb converter.
// NOTE(review): in the code visible in this listing only shuffle, resize_width,
// resize_height and channel are read; the remaining flags appear to be carried
// over from Caffe's convert_imageset tool - confirm before relying on them.
DEFINE_bool(gray, false, "when this option is on, treat images as grayscale ones");
DEFINE_bool(shuffle, false, "randomly shuffle the order of images and their labels");
DEFINE_string(backend, "leveldb", "the backend {lmdb, leveldb} for storing the result");
DEFINE_int32(resize_width, 100, "Width images are resized to");   // adjust to your data
DEFINE_int32(resize_height, 100, "Height images are resized to"); // adjust to your data
DEFINE_bool(check_size, false, "When this option is on, check that all the datum have the same size");
DEFINE_bool(encoded, false, "When this option is on, the encoded image will be save in datum");
DEFINE_string(encode_type, "", "Optional: What type should we encode the image as ('png','jpg',...).");
DEFINE_int32(channel, 3, "channel numbers of the image");
staticbool ReadImageToMemory(const string &FileName, const int Height, const intWidth, char *Pixels)
{
cv::Mat OriginImage =cv::imread(FileName);
CHECK(OriginImage.data) <<"Failed to read the image.\n";
cv::Mat ResizeImage;
cv::resize(OriginImage, ResizeImage,cv::Size(Width, Height));
CHECK(ResizeImage.rows == Height)<< "The heighs of Image is no equal to the input height.\n";
CHECK(ResizeImage.cols == Width)<< "The width of Image is no equal to the input width.\n";
CHECK(ResizeImage.channels() == 3)<< "The channel of Image is no equal to three.\n";
for (int HeightIndex = 0; HeightIndex< Height; ++HeightIndex)
{
const uchar* ptr =ResizeImage.ptr<uchar>(HeightIndex);
int img_index = 0;
for (int WidthIndex = 0;WidthIndex < Width; ++WidthIndex)
{
for (intChannelIndex = 0; ChannelIndex < ResizeImage.channels(); ++ChannelIndex)
{
intdatum_index = (ChannelIndex * Height + HeightIndex) * Width + WidthIndex;
*(Pixels +datum_index) = static_cast<char>(ptr[img_index++]);
}
}
}
return true;
}
intgetSiameseNetInputFormat()
{
#ifndefGFLAGS_GFLAGS_H_
namespace gflags = google;
#endif
gflags::SetUsageMessage("Convert aset of color images to the leveldb\n"
"format used as inputfor Caffe.\n"
"Usage:\n"
" convert_imageset [FLAGS] ROOTFOLDER/LISTFILE DB_NAME\n");
//caffe::GlobalInit(&ac, av);
// 讀取圖像名字和標籤
std::ifstreaminfile("../../trainData.txt");//"../../valData.txt"
std::vector<std::pair<std::string,std::string> > lines;
std::string filename;
std::string pairname;
int label;
std::vector<int> labels;
while (infile >> filename>> pairname >> label)
{
string filename1 ="../../" + filename;
string pairname1 ="../../" + pairname;
lines.push_back(std::make_pair(filename1,pairname1));
labels.push_back(label);
}
// 打亂圖片順序
if (FLAGS_shuffle)
{
LOG(INFO) <<"Shuffling data";
shuffle(lines.begin(),lines.end());
}
LOG(INFO) << "A total of" << lines.size() << " images.";
//設置圖像的高度和寬度
int resize_height = std::max<int>(0,FLAGS_resize_height);
int resize_width =std::max<int>(0, FLAGS_resize_width);
int channel = std::max<int>(1,FLAGS_channel);
//打開數據庫
leveldb::DB* db;
leveldb::Options options;
options.create_if_missing = true;
options.error_if_exists = true;
leveldb::Status status =leveldb::DB::Open(options, "../../train_leveldb", &db);//"../../val_leveldb"
CHECK(status.ok()) <<"Failed to open leveldb " << "../../train_leveldb"<< ". Is it already existing?";// "../../val_leveldb"
//保存到leveldb
char* Pixels = new char[2 *resize_height * resize_width * channel];
const int kMaxKeyLength = 10;
char key[kMaxKeyLength];
std::string value;
caffe::Datum datum;
datum.set_channels(2 * channel);
datum.set_height(resize_height);
datum.set_width(resize_width);
for (int LineIndex = 0; LineIndex <lines.size(); LineIndex++)
{
char* FirstImagePixel =Pixels;
ReadImageToMemory(lines[LineIndex].first,resize_height, resize_width, FirstImagePixel);
char *SecondImagePixel =Pixels + resize_width * resize_height * channel;
ReadImageToMemory(lines[LineIndex].second,resize_height, resize_width, SecondImagePixel);
datum.set_data(Pixels, 2 *resize_height * resize_width * channel);
datum.set_label(labels[LineIndex]);
datum.SerializeToString(&value);
int key_value =(int)(LineIndex);
_snprintf(key, kMaxKeyLength,"%08d", key_value);
string keystr(key);
cout << "label:" << datum.label() << ' ' << "key index: "<< keystr << endl;
db->Put(leveldb::WriteOptions(),std::string(key), value);
}
delete db;
delete[] Pixels;
return 0;
}
提示:本小節所有資料在“Caffe學習筆記系列”文件夾—>“CaffeTest3”文件夾中。
該系列的代碼鏈接如下:https://pan.baidu.com/s/1kd7ATJyoF_Dhlnx_9IIa_Q 密碼:6vgq