版權聲明:未經允許請勿用於商業用途,轉載請註明出處:http://blog.csdn.net/mounty_fsc/
本文地址:http://blog.csdn.net/mounty_fsc/article/details/51090114
在訓練lenet的train_lenet.sh
中內容爲:
./build/tools/caffe train –solver=examples/mnist/lenet_solver.prototxt
由此可知,訓練網咯模型是由tools/caffe.cpp
生成的工具caffe
在模式train
下完成的。
初始化過程總的來說,從main()
、train()
中創建Solver
,在Solver
中創建Net
,在Net
中創建Layer
.
1 程序入口
- 找到
caffe.cpp
的main
函數中,通過GetBrewFunction(caffe::string(argv[1]))()
調用執行train()
函數。 train中
,通過參數-examples/mnist/lenet_solver.prototxt
把solver
參數讀入solver_param
中。-
隨後註冊並定義
solver
的指針(見第2節)shared_ptr<caffe::Solver<float> > solver(caffe::SolverRegistry<float>::CreateSolver(solver_param))
- 1
- 2
- 1
- 2
-
調用
solver
的Solver()
方法。多個GPU涉及到GPU間帶異步處理問題(見第3節)if (gpus.size() > 1) { caffe::P2PSync<float> sync(solver, NULL, solver->param()); sync.run(gpus); } else { LOG(INFO) << "Starting Optimization"; solver->Solve(); }
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 1
- 2
- 3
- 4
- 5
- 6
- 7
2 Solver的創建
在1中,Solver
的指針solver
是通過SolverRegistry::CreateSolver
創建的,CreateSolver
函數中值得注意帶是return registry[type](param)
// Get a solver using a SolverParameter.
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
const string& type = param.type();
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
return registry[type](param);
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
其中:
registry
是一個map<string,Creator>: typedef std::map<string, Creator> CreatorRegistry
其中Creator
是一個函數指針類型: typedef Solver<Dtype>* (*Creator)(const SolverParameter&)
registry[type]
爲一個函數指針變量,在Lenet5
中,此處具體的值爲 caffe::Creator_SGDSolver<float>(caffe::SolverParameter const&)
其中Creator_SGDSolver
在以下宏中定義, REGISTER_SOLVER_CLASS(SGD)
該宏完全展開得到的內容爲:
template <typename Dtype> \
Solver<Dtype>* Creator_SGDSolver( \
const SolverParameter& param) \
{ \
return new SGDSolver<Dtype>(param); \
} \
static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>); \
static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
從上可以看出,registry[type](param)
中實際上調用了SGDSolver
帶構造方法,事實上,網絡是在SGDSolver
的構造方法中初始化的。
SGDSolver
的定義如下:
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
public:
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) { PreSolve(); }
......
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
SGDSolver
繼承與Solver<Dtype>
,因而new SGDSolver<Dtype>(param)
將執行Solver<Dtype>
的構造函數,然後調用自身構造函數。整個網絡帶初始化即在這裏面完成(詳見本系列博文(三))。
3 Solver::Solve()函數
在這個函數裏面,程序執行完網絡的完整訓練過程。
核心代碼如下:
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Step(param_.max_iter() - iter_);
//..
Snapshot();
//..
// some additional display
// ...
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
說明:
- 值得關注的代碼是
Step()
,在該函數中,值得了param_.max_iter()
輪迭代(10000) - 在Snapshot()中序列化model到文件
4 Solver::Step()函數
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
//10000輪迭代
while (iter_ < stop_iter) {
// 每隔500輪進行一次測試
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
// 測試網絡,實際是執行前向傳播計算loss
TestAll();
}
// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
// 執行反向傳播,前向計算損失loss,並計算loss關於權值的偏導
loss += net_->ForwardBackward(bottom_vec);
}
// 平滑loss,計算結果用於輸出調試等
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
// 通過反向傳播計算的偏導更新權值
ApplyUpdate();
}
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
4.1 Solver::TestAll()函數
在TestAll()
中,調用Test(test_net_id)
對每個測試網絡test_net(不是訓練網絡train_net)進行測試。在Lenet中,只有一個測試網絡,所以只調用一次Test(0)
進行測試。
Test()函數裏面做了兩件事:
- 前向計算網絡,得到網絡損失,見 (Caffe,LeNet)前向計算(五)
- 通過測試網絡的第11層accuracy層,與第12層loss層結果統計accuracy與loss信息。
4.2 Net::ForwardBackward()函數
Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
Dtype loss;
Forward(bottom, &loss);
Backward();
return loss;
}
- 1
- 2
- 3
- 4
- 5
- 6
- 1
- 2
- 3
- 4
- 5
- 6
說明:
- 前向計算。計算網絡損失loss,參考 (Caffe,LeNet)前向計算(五)
- 反向傳播。計算loss關於網絡權值的偏導,參考 (Caffe,LeNet)反向傳播(六)
4.3 Solver::ApplyUpdate()函數
根據反向傳播階段計算的loss關於網絡權值的偏導,使用配置的學習策略,更新網絡權值從而完成本輪學習。詳見 (Caffe,LeNet)權值更新(七)
5 訓練完畢
至此,網絡訓練優化完成。在第3部分solve()函數中,最後對訓練網絡與測試網絡再執行一輪額外的前行計算求得loss,以進行測試。