Caffe Study Notes Series 3: Siamese-Network Model Training and Feature Extraction


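This part explains how to train a Siamese network model and how to prepare the input data format that the network requires. Feature extraction is not covered again, because it works the same way as in Part 2 of the Caffe study notes series.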

Create a "CaffeTest3" folder inside the "Caffe學習筆記系列" (Caffe study notes series) folder; all operations in this part take place in that folder.

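This part has two sections: the first trains a Siamese network on the MNIST dataset, and the second trains one on your own labeled data. The first section is fairly simple; the difficulty in the second lies in preparing the input data format. The first section is summarized briefly below.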

I. Training a Siamese network on the MNIST dataset

The materials for this section are in the "Caffe學習筆記系列" folder -> "CaffeTest3" folder -> "SiameseNetMnist" folder.

This section consists of the following steps:

1. Inside the "CaffeTest3" folder, create a "SiameseNetMnist" folder; all MNIST training steps are performed in this folder.

2. Download the MNIST dataset into a "mnist" folder, and copy Caffe's mnist_siamese.prototxt, mnist_siamese_solver.prototxt and mnist_siamese_train_test.prototxt into the "SiameseNetMnist" folder.

3. Create a batch file create_mnist_train_leveldb.bat with the following contents, and run it to generate the training data:

..\..\CaffeDev\caffe-master\Build\x64\Release\convert_mnist_siamese_data.exe ./mnist/train-images-idx3-ubyte ./mnist/train-labels-idx1-ubyte trainleveldb
pause

4. Create a batch file create_mnist_test_leveldb.bat with the following contents, and run it to generate the test data:

..\..\CaffeDev\caffe-master\Build\x64\Release\convert_mnist_siamese_data.exe ./mnist/t10k-images-idx3-ubyte ./mnist/t10k-labels-idx1-ubyte testleveldb
pause

5. Create a batch file train.bat with the following contents, and run it to start training:

..\..\CaffeDev\caffe-master\Build\x64\Release\caffe.exe train --solver=mnist_siamese_solver.prototxt
pause 
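The pair databases produced by steps 3 and 4 store two MNIST digits per record plus a similarity label. Below is a minimal sketch of that pairing idea. It assumes both members of a pair are drawn at random and the label is 1 for a same-class pair and 0 otherwise, which is the convention of Caffe's ContrastiveLoss layer. `Pair` and `makeRandomPairs` are hypothetical names used for illustration; this is not the source of convert_mnist_siamese_data.exe.

```cpp
#include <cstddef>
#include <random>
#include <vector>

// One training pair: indices of the two samples plus the similarity label.
struct Pair {
    std::size_t first;
    std::size_t second;
    int similar;  // 1 = same class, 0 = different class (ContrastiveLoss convention)
};

// Hypothetical helper: draw `count` random pairs from a label array.
// Both members are sampled uniformly and independently.
std::vector<Pair> makeRandomPairs(const std::vector<int>& labels,
                                  std::size_t count, unsigned seed) {
    std::vector<Pair> pairs;
    if (labels.empty()) return pairs;  // nothing to pair
    std::mt19937 rng(seed);
    std::uniform_int_distribution<std::size_t> pick(0, labels.size() - 1);
    pairs.reserve(count);
    for (std::size_t n = 0; n < count; ++n) {
        const std::size_t i = pick(rng);
        const std::size_t j = pick(rng);
        pairs.push_back({i, j, labels[i] == labels[j] ? 1 : 0});
    }
    return pairs;
}
```

Because partners are drawn uniformly, with ten roughly balanced MNIST classes only about one pair in ten is a positive pair; the custom-data code in Appendix 1 instead fixes the positive-to-negative ratio explicitly.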


II. Training a Siamese network on your own data

The materials for this section are in the "Caffe學習筆記系列" folder -> "CaffeTest3" folder -> "SiameseNetMyData" folder.

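This section explains how to train a Siamese network on your own data; the tricky part is producing the training data in the right format. A Siamese network aims to place samples of the same class as close together as possible and samples of different classes as far apart as possible. One natural input format is "Data/0/41_20160503071203/0_0.jpg (image 1)  0 (class of image 1)  Data/0/81_20160503071250/0_19.jpg (image 2)  0 (class of image 2)", where the program decides whether the two input images belong to the same class by comparing their class labels. Caffe does not support this format out of the box; it requires changes to the Caffe source code, which the next part of this series covers.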

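The Siamese input format that Caffe supports is:

"Data/0/41_20160503071203/0_0.jpg (image 1)   Data/0/81_20160503071250/0_19.jpg (image 2)   1/0", where the label is 1 if image 1 and image 2 belong to the same class and 0 otherwise (Caffe's ContrastiveLoss layer treats label 1 as a similar pair and 0 as a dissimilar pair). The rest of this section focuses on this second format.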
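To make the relationship between the two formats concrete, the hypothetical helper below turns a four-column line (image 1, class 1, image 2, class 2) into the three-column line Caffe's pair input expects. It is a sketch that assumes whitespace-separated fields with no spaces inside paths, and it writes 1 for a same-class pair and 0 otherwise, matching how Caffe's ContrastiveLoss layer interprets the label.

```cpp
#include <sstream>
#include <string>

// Hypothetical helper: "img1 label1 img2 label2" -> "img1 img2 sim",
// where sim is 1 for a same-class pair and 0 otherwise.
// Returns an empty string if the line does not contain all four fields.
std::string toPairLine(const std::string& fourColumnLine) {
    std::istringstream in(fourColumnLine);
    std::string img1, img2;
    int label1 = 0, label2 = 0;
    if (!(in >> img1 >> label1 >> img2 >> label2)) return "";
    const int sim = (label1 == label2) ? 1 : 0;
    return img1 + " " + img2 + " " + std::to_string(sim);
}
```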

1. Inside the "CaffeTest3" folder, create a "SiameseNetMyData" folder; all steps for training on your own dataset are performed in this folder.

2. Create a "Data" folder inside it; its contents are the same data as in Part 1 of the series.

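3. From the images in the Data folder, generate a txt file whose lines look like "Data/0/41_20160503071203/0_1.jpg Data/0/81_20160503071250/0_10.jpg 1" (label 1 because both images are from class 0; Caffe's ContrastiveLoss treats 1 as a similar pair). C++ code for this is provided in Appendix 1.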

4. Convert the pairs listed in the txt file into LevelDB format; the code is in Appendix 2.

5. Create a batch file train.bat with the following contents, and run it:

..\..\CaffeDev\caffe-master\Build\x64\Release\caffe.exe train --solver=mnist_siamese_solver.prototxt
pause

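Note: in mnist_siamese_train_test.prototxt (the network definition the solver points to), set the Slice layer parameters to match the number of channels per image:

slice_dim: 1
slice_point: 3    # RGB: 3, grayscale: 1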
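To see why slice_point equals the channel count of one image, here is a plain C++ sketch of what a Slice layer with slice_dim: 1 does to one stacked pair. It assumes a single record in channel-major (CHW) order, as written by Appendix 2. `splitPairChannels` is a hypothetical name, not a Caffe function.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

using Bytes = std::vector<unsigned char>;

// Hypothetical illustration of slicing one stacked pair at slice_point = C:
// channels [0, C) are image 1 and channels [C, 2C) are image 2.
// Returns two empty vectors if the buffer does not hold exactly 2 * C * H * W values.
std::pair<Bytes, Bytes> splitPairChannels(const Bytes& pairBlob, std::size_t C,
                                          std::size_t H, std::size_t W) {
    const std::size_t imageSize = C * H * W;
    if (pairBlob.size() != 2 * imageSize) return {};
    // In CHW order each image occupies one contiguous block, so a channel
    // slice is simply a split of the buffer at C * H * W.
    const auto mid = pairBlob.begin() + static_cast<std::ptrdiff_t>(imageSize);
    return {Bytes(pairBlob.begin(), mid), Bytes(mid, pairBlob.end())};
}
```

For RGB images each record has 6 channels and the split happens after channel 3; for grayscale images the record has 2 channels and slice_point is 1.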

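This completes the preparation of the input data that the Siamese network needs. The Siamese network used in this part is the one that ships with Caffe; I have also adapted AlexNet, combining its layer settings into a Siamese network, which is in the "Caffe學習筆記系列" folder -> "CaffeTest3" folder -> "MySiameseNet" folder. That folder is encrypted; message me privately if you need it.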

Appendix 1:

#include <vector>
#include <opencv2/opencv.hpp>  // OpenCV 2.4.x: cv::Directory is part of the contrib module
#include <iostream>
#include <string>
#include <fstream>
#include <random>

using namespace std;
using namespace cv;

// Generates the txt pair lists for Siamese-network training.
// Image layout:  Data/<class>/<camera id (41, 81)>/xx.jpg
// Output format: Data/<class>/<camera>/xx.jpg Data/<class>/<camera>/xx.jpg 1/0
//                (1 = same class, 0 = different class, as Caffe's ContrastiveLoss layer expects)
// Every trainToval-th image of a camera folder goes to the validation list,
// so train:val is about 4:1 with trainToval = 5.
// Same-class pairs : different-class pairs = 1:diffTosame = 1:10

const string kRoot = "../../";

// Turns "../../Data\0\41_x\0_1.jpg" or "../../Data/0/41_x/0_1.jpg" into "Data/0/41_x/0_1.jpg",
// the path form Appendix 2 expects (it prepends "../../" itself).
static string toRelative(string path)
{
    for (size_t c = 0; c < path.size(); c++)
        if (path[c] == '\\') path[c] = '/';
    if (path.compare(0, kRoot.size(), kRoot) == 0)
        path = path.substr(kRoot.size());
    return path;
}

int main()
{
    const int trainToval = 5;
    const int diffTosame = 10;
    const int numstart = 0, numend = 93;  // class folders 0..92

    ofstream trainOut(kRoot + "trainData.txt");
    ofstream valOut(kRoot + "valData.txt");
    // rand() only reaches 32767 on MSVC, too small to index large image lists uniformly
    mt19937 rng(random_device{}());

    // imagesByClass[i][j] = .jpg files of class i, camera folder j (listed once, not per image)
    Directory dir;
    vector<vector<vector<string>>> imagesByClass(numend);
    for (int i = numstart; i < numend; i++)
    {
        vector<string> cameraFolders = dir.GetListFolders(kRoot + "Data/" + to_string(i), "*", true);
        for (size_t j = 0; j < cameraFolders.size(); j++)
            imagesByClass[i].push_back(dir.GetListFiles(cameraFolders[j], "*.jpg", true));
    }

    for (int i = numstart; i < numend; i++)
    {
        // Negative candidates: every image of every other class
        vector<string> otherDiff;
        for (int i1 = numstart; i1 < numend; i1++)
        {
            if (i1 == i) continue;
            for (const auto& camera : imagesByClass[i1])
                otherDiff.insert(otherDiff.end(), camera.begin(), camera.end());
        }

        const vector<vector<string>>& cameras = imagesByClass[i];
        for (size_t j = 0; j < cameras.size(); j++)
        {
            // Positive candidates: same class, but taken by the other cameras
            vector<string> otherSame;
            for (size_t j1 = 0; j1 < cameras.size(); j1++)
                if (j1 != j) otherSame.insert(otherSame.end(), cameras[j1].begin(), cameras[j1].end());
            if (otherSame.empty() || otherDiff.empty()) continue;  // cannot form both kinds of pairs

            uniform_int_distribution<size_t> pickSame(0, otherSame.size() - 1);
            uniform_int_distribution<size_t> pickDiff(0, otherDiff.size() - 1);

            const vector<string>& tmp = cameras[j];
            for (size_t k = 0; k < tmp.size(); k++)
            {
                ofstream& out = (k % trainToval != 0) ? trainOut : valOut;  // training or validation list
                out << toRelative(tmp[k]) << " " << toRelative(otherSame[pickSame(rng)]) << " " << 1 << endl;  // positive pair
                for (int n = 0; n < diffTosame; n++)
                    out << toRelative(tmp[k]) << " " << toRelative(otherDiff[pickDiff(rng)]) << " " << 0 << endl;  // negative pair
            }
        }
    }
    return 0;
}
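Before converting the lists to LevelDB it can be worth checking them. The following is a minimal sketch that assumes the directory layout Data/<class>/<camera>/<file>.jpg described above and the label convention 1 = same class, 0 = different class. `classOfPath` and `pairLineIsConsistent` are hypothetical helpers; they find the class folder by searching backwards for '/'.

```cpp
#include <cstddef>
#include <sstream>
#include <string>

// Hypothetical helper: class folder of "…/Data/<class>/<camera>/<file>.jpg".
// Returns an empty string if the path has fewer than three '/' separators.
std::string classOfPath(const std::string& path) {
    const std::size_t fileSlash = path.rfind('/');
    if (fileSlash == std::string::npos || fileSlash == 0) return "";
    const std::size_t cameraSlash = path.rfind('/', fileSlash - 1);
    if (cameraSlash == std::string::npos || cameraSlash == 0) return "";
    const std::size_t classSlash = path.rfind('/', cameraSlash - 1);
    if (classSlash == std::string::npos) return "";
    return path.substr(classSlash + 1, cameraSlash - classSlash - 1);
}

// True if a line "img1 img2 sim" has sim == 1 exactly when both images
// come from the same class folder.
bool pairLineIsConsistent(const std::string& line) {
    std::istringstream in(line);
    std::string a, b;
    int sim = -1;
    if (!(in >> a >> b >> sim)) return false;
    const std::string ca = classOfPath(a), cb = classOfPath(b);
    if (ca.empty() || cb.empty()) return false;
    return sim == (ca == cb ? 1 : 0);
}
```

Running pairLineIsConsistent over every line of trainData.txt and valData.txt should return true for all of them, and with diffTosame = 10 about one line in eleven should carry label 1.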

Appendix 2:

#include "getSiameseNetInputFormat.h"  // the author's own header
// Dependencies used below (harmless if the header already includes them)
#include <algorithm>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <leveldb/db.h>
#include <opencv2/opencv.hpp>
#include "caffe/proto/caffe.pb.h"

DEFINE_bool(gray, false, "when this option is on, treat images as grayscale ones");
DEFINE_bool(shuffle, false, "randomly shuffle the order of images and their labels");
DEFINE_string(backend, "leveldb", "the backend {lmdb, leveldb} for storing the result");  // only leveldb is written below
DEFINE_int32(resize_width, 100, "Width images are resized to");    // adjust to your data
DEFINE_int32(resize_height, 100, "Height images are resized to");  // adjust to your data
DEFINE_bool(check_size, false, "When this option is on, check that all the datum have the same size");
DEFINE_bool(encoded, false, "When this option is on, the encoded image will be saved in datum");
DEFINE_string(encode_type, "", "Optional: What type should we encode the image as ('png','jpg',...).");
DEFINE_int32(channel, 3, "channel numbers of the image (3 = color, 1 = grayscale)");

// Reads one image, resizes it to Height x Width and writes it to Pixels in the
// channel-major (CHW) order a Caffe Datum uses. Channels must be 1 or 3.
static bool ReadImageToMemory(const std::string& FileName, const int Height, const int Width,
                              const int Channels, char* Pixels)
{
    cv::Mat OriginImage = cv::imread(FileName, Channels == 1 ? cv::IMREAD_GRAYSCALE : cv::IMREAD_COLOR);
    CHECK(OriginImage.data) << "Failed to read the image " << FileName;

    cv::Mat ResizeImage;
    cv::resize(OriginImage, ResizeImage, cv::Size(Width, Height));
    CHECK(ResizeImage.rows == Height) << "The height of the image does not equal the input height.";
    CHECK(ResizeImage.cols == Width) << "The width of the image does not equal the input width.";
    CHECK(ResizeImage.channels() == Channels) << "The image does not have " << Channels << " channel(s).";

    for (int HeightIndex = 0; HeightIndex < Height; ++HeightIndex)
    {
        const uchar* ptr = ResizeImage.ptr<uchar>(HeightIndex);  // one row, interleaved (HWC)
        int img_index = 0;
        for (int WidthIndex = 0; WidthIndex < Width; ++WidthIndex)
        {
            for (int ChannelIndex = 0; ChannelIndex < Channels; ++ChannelIndex)
            {
                int datum_index = (ChannelIndex * Height + HeightIndex) * Width + WidthIndex;
                Pixels[datum_index] = static_cast<char>(ptr[img_index++]);
            }
        }
    }
    return true;
}

int getSiameseNetInputFormat()
{
#ifndef GFLAGS_GFLAGS_H_
    namespace gflags = google;
#endif
    gflags::SetUsageMessage("Convert a set of image pairs to the leveldb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/LISTFILE DB_NAME\n");
    // caffe::GlobalInit(&ac, av);  // flags are never parsed here, so the FLAGS_* defaults above apply

    // Read the image-pair names and labels (use "../../valData.txt" for the validation set).
    // Paths and label are kept in one entry so that shuffling cannot separate them.
    std::ifstream infile("../../trainData.txt");
    std::vector<std::pair<std::pair<std::string, std::string>, int> > lines;
    std::string filename, pairname;
    int label;
    while (infile >> filename >> pairname >> label)
        lines.push_back(std::make_pair(std::make_pair("../../" + filename, "../../" + pairname), label));

    // Shuffle the pairs
    if (FLAGS_shuffle)
    {
        LOG(INFO) << "Shuffling data";
        std::shuffle(lines.begin(), lines.end(), std::mt19937(std::random_device()()));
    }
    LOG(INFO) << "A total of " << lines.size() << " image pairs.";

    // Image size; each datum stores the two images stacked along the channel axis
    const int resize_height = std::max<int>(0, FLAGS_resize_height);
    const int resize_width = std::max<int>(0, FLAGS_resize_width);
    const int channel = std::max<int>(1, FLAGS_channel);

    // Open the database (use "../../val_leveldb" for the validation set)
    const std::string db_path = "../../train_leveldb";
    leveldb::DB* db;
    leveldb::Options options;
    options.create_if_missing = true;
    options.error_if_exists = true;
    leveldb::Status status = leveldb::DB::Open(options, db_path, &db);
    CHECK(status.ok()) << "Failed to open leveldb " << db_path << ". Is it already existing?";

    // Save to leveldb
    const int image_size = resize_height * resize_width * channel;
    std::vector<char> Pixels(2 * image_size);
    const int kMaxKeyLength = 10;
    char key[kMaxKeyLength];
    std::string value;
    caffe::Datum datum;
    datum.set_channels(2 * channel);
    datum.set_height(resize_height);
    datum.set_width(resize_width);
    for (size_t LineIndex = 0; LineIndex < lines.size(); LineIndex++)
    {
        ReadImageToMemory(lines[LineIndex].first.first, resize_height, resize_width, channel, &Pixels[0]);
        ReadImageToMemory(lines[LineIndex].first.second, resize_height, resize_width, channel, &Pixels[0] + image_size);
        datum.set_data(&Pixels[0], Pixels.size());
        datum.set_label(lines[LineIndex].second);
        datum.SerializeToString(&value);

        _snprintf(key, kMaxKeyLength, "%08d", static_cast<int>(LineIndex));  // MSVC; use snprintf elsewhere
        std::cout << "label: " << datum.label() << " key index: " << key << std::endl;
        CHECK(db->Put(leveldb::WriteOptions(), std::string(key), value).ok()) << "Failed to write key " << key;
    }
    delete db;
    return 0;
}
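The inner loops of ReadImageToMemory convert OpenCV's interleaved layout into the planar layout a Caffe Datum uses. The sketch below isolates the same index arithmetic from OpenCV and LevelDB so it can be checked on a tiny buffer; `hwcToChw` is a hypothetical name.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the index arithmetic in ReadImageToMemory:
// OpenCV keeps pixel (h, w, c) at (h * W + w) * C + c (interleaved, HWC),
// while a Caffe Datum expects it at (c * H + h) * W + w (planar, CHW).
// Returns an empty vector if the input size does not equal H * W * C.
std::vector<unsigned char> hwcToChw(const std::vector<unsigned char>& hwc,
                                    std::size_t H, std::size_t W, std::size_t C) {
    std::vector<unsigned char> chw(H * W * C);
    if (hwc.size() != chw.size()) return {};
    for (std::size_t h = 0; h < H; ++h)
        for (std::size_t w = 0; w < W; ++w)
            for (std::size_t c = 0; c < C; ++c)
                chw[(c * H + h) * W + w] = hwc[(h * W + w) * C + c];
    return chw;
}
```

With the default 100 × 100 RGB settings, each pair datum therefore holds 2 × 3 × 100 × 100 = 60,000 bytes: the three planes of image 1 followed by the three planes of image 2, which is what the Slice layer splits at slice_point: 3.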

Tip: all materials for this part are in the "Caffe學習筆記系列" folder -> "CaffeTest3" folder.

Code for this series: https://pan.baidu.com/s/1kd7ATJyoF_Dhlnx_9IIa_Q (password: 6vgq)

