adaboost訓練弱分類器的原理見上一個博客::http://blog.csdn.net/lanxuecc/article/details/52681525
opencv中adaboost訓練弱分類器的主體代碼是函數cvCreateCARTClassifier,這個函數通過大致邏輯是:
1、通過調用訓練結點函數cvCreateMTStumpClassifier來創建根結點
2、在要求弱分類器特徵不只一個的情況下,通過分裂結點來增加新的特徵形成CART樹的弱分類器。
源碼及註釋如下
CV_BOOST_IMPL
CvClassifier* cvCreateCARTClassifier( CvMat* trainData, //預計算的訓練樣本每個特徵的值矩陣
int flags, //1表示樣本按行排列,0表示樣本按行排列
CvMat* trainClasses, //訓練樣本類別向量,如果是正樣本標識爲1,負樣本標識爲-1
CvMat* typeMask, //爲了便於回調函數而統一格式的變量
CvMat* missedMeasurementsMask, //同上
CvMat* compIdx, //特徵序列向量
CvMat* sampleIdx, //樣本序列向量
CvMat* weights, //樣本權值向量
CvClassifierTrainParams* trainParams ) //傳入一些弱分類器所需的參數比如需要幾個特徵,和一些需用的分類函數指針
{
CvCARTClassifier* cart = NULL;//CART樹狀弱分類器
size_t datasize = 0;
int count = 0; // CART中的節點數目
int i = 0;
int j = 0;
CvCARTNode* intnode = NULL; // CART節點
CvCARTNode* list = NULL; // 候選節點鏈表
int listcount = 0; // 候選節點個數
CvMat* lidx = NULL; // 左子節點樣本序列
CvMat* ridx = NULL; // 右子節點樣本序列
float maxerrdrop = 0.0F;
int idx = 0;
//定義節點分裂函數指針 這個函數指針指向的是函數icvSplitIndicesCallback
void (*splitIdxCallback)( int compidx, float threshold,
CvMat* idx, CvMat** left, CvMat** right,
void* userdata );
void* userdata;
//設置非葉子節點個數
count = ((CvCARTTrainParams*) trainParams)->count; /*弱分類器的特徵個數,一般都只有一個*/
assert( count > 0 );
/*分配一個弱分類器的內存空間*/
datasize = sizeof( *cart ) + (sizeof( float ) + 3 * sizeof( int )) * count +
sizeof( float ) * (count + 1);
cart = (CvCARTClassifier*) cvAlloc( datasize );
memset( cart, 0, datasize );
/*初始化弱分類器*/
cart->count = count;
cart->eval = cvEvalCARTClassifier; /*弱分類器使用函數*/
cart->save = NULL;
cart->release = cvReleaseCARTClassifier; /*弱分類器內存釋放函數 */
cart->compidx = (int*) (cart + 1); //非葉子節點的最優Haar特徵序號
cart->threshold = (float*) (cart->compidx + count); //非葉子節點的最優Haar特徵閾值
cart->left = (int*) (cart->threshold + count); //左子節點序號,包含葉子節點序號
cart->right = (int*) (cart->left + count); //右子節點序號,包含葉子節點序號
cart->val = (float*) (cart->right + count); //葉子節點輸出置信度數組
datasize = sizeof( CvCARTNode ) * (count + count);
intnode = (CvCARTNode*) cvAlloc( datasize );
memset( intnode, 0, datasize );
list = (CvCARTNode*) (intnode + count);
//節點分裂函數指針,一般爲icvSplitIndicesCallback函數
splitIdxCallback = ((CvCARTTrainParams*) trainParams)->splitIdx;
userdata = ((CvCARTTrainParams*) trainParams)->userdata;
if( splitIdxCallback == NULL )//如果沒有用默認的節點分裂函數
{
splitIdxCallback = ( CV_IS_ROW_SAMPLE( flags ) )
? icvDefaultSplitIdx_R : icvDefaultSplitIdx_C;//R代表樣本按行排列,C代表樣本按列排列
userdata = trainData;
}
/* create root of the tree */
//創建CART弱分類器的根節點,如果該弱分類器只有一個特徵,那這裏就創建了弱分類器,不用後面作結點分裂
//stumpConstructor是一個函數指針,他指向cvCreateMTStumpClassifier函數,所以這裏調用的是這個函數
intnode[0].sampleIdx = sampleIdx;
intnode[0].stump = (CvStumpClassifier*)
((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
trainClasses, typeMask, missedMeasurementsMask, compIdx, sampleIdx, weights,
((CvCARTTrainParams*) trainParams)->stumpTrainParams );
cart->left[0] = cart->right[0] = 0;
/* build tree */
//創建樹狀弱分類器,lerror或者rerror不爲0代表着當前節點爲非葉子節點
listcount = 0;
for( i = 1; i < count; i++ )/*當弱分類器只有一個特徵也就是隻一個非葉子結點時,不會走入這個分支*/
{
/* split last added node */
/*這個函數的作用就是:::基於當前結點的閾值將樣本分類,
分類爲負樣本的樣本存儲在lidx中,分類爲正樣本的樣本存儲在ridx,
後續從當前結點左分支分裂時,用lidx樣本來訓練一個結點,
從當前結點右分支分裂時,用ridx樣本來訓練一個結點*/
splitIdxCallback( intnode[i-1].stump->compidx, intnode[i-1].stump->threshold,
intnode[i-1].sampleIdx, &lidx, &ridx, userdata );
//爲分裂之後的非葉子節點計算最優特徵
if( intnode[i-1].stump->lerror != 0.0F )
{
//小於閾值的樣本集合,就是當前結點的左分支結點的訓練
list[listcount].sampleIdx = lidx;
//基於新樣本集合尋找最優特徵,重複調用訓練樁的函數來訓練
list[listcount].stump = (CvStumpClassifier*)
((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
trainClasses, typeMask, missedMeasurementsMask, compIdx,
list[listcount].sampleIdx,
weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
//計算信息增益(這裏是error的下降程度)
list[listcount].errdrop = intnode[i-1].stump->lerror
- (list[listcount].stump->lerror + list[listcount].stump->rerror);
list[listcount].leftflag = 1;
list[listcount].parent = i-1;
listcount++;
}
else
{
cvReleaseMat( &lidx );
}
//同上,左分支換成右分支,偏向於右分支
if( intnode[i-1].stump->rerror != 0.0F )
{
list[listcount].sampleIdx = ridx;
list[listcount].stump = (CvStumpClassifier*)
((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
trainClasses, typeMask, missedMeasurementsMask, compIdx,
list[listcount].sampleIdx,
weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
list[listcount].errdrop = intnode[i-1].stump->rerror
- (list[listcount].stump->lerror + list[listcount].stump->rerror);
list[listcount].leftflag = 0;//標識訓練出來的節點是當前結點左分支結點還是右還是右分支結點
list[listcount].parent = i-1;
listcount++;
}
else
{
cvReleaseMat( &ridx );
}
if( listcount == 0 ) break;
/*find the best node to be added to the tree*/
/*找到已經分裂得到的所有結點中,使分類誤差下降最快的那個結點,
把它加入到CART樹中去,構成弱分類器的一部分*/
idx = 0;
maxerrdrop = list[idx].errdrop;
for( j = 1; j < listcount; j++ )
{
if( list[j].errdrop > maxerrdrop )
{
idx = j;
maxerrdrop = list[j].errdrop;
}
}
//確定誤差下降最快的結點應該加入到CART樹中的位置
intnode[i] = list[idx];
if( list[idx].leftflag )
{
cart->left[list[idx].parent] = i;
}
else
{
cart->right[list[idx].parent] = i;
}
//將被選中放入CART樹的結點刪除
if( idx != (listcount - 1) )
{
list[idx] = list[listcount - 1];
}
listcount--;
}
/* fill <cart> fields */
// 這段代碼用於確定樹中節點最優特徵序號、閾值與葉子節點序號和輸出置信度
// left與right大於等於0,爲0代表葉子節點
// 就算CART中只有一個節點,仍舊需要設置葉子節點
j = 0;
cart->count = 0;
for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
{
cart->count++;
cart->compidx[i] = intnode[i].stump->compidx;
cart->threshold[i] = intnode[i].stump->threshold;
/* leaves */
if( cart->left[i] <= 0 )//確定葉子序號與葉子的輸出置信度
{
cart->left[i] = -j;
cart->val[j] = intnode[i].stump->left;//這個left是float值,不是CVMat*
j++;
}
if( cart->right[i] <= 0 )
{
cart->right[i] = -j;
cart->val[j] = intnode[i].stump->right;
j++;
}
}
/* CLEAN UP *//*一些臨時用的內存釋放*/
for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
{
intnode[i].stump->release( (CvClassifier**) &(intnode[i].stump) );
if( i != 0 )
{
cvReleaseMat( &(intnode[i].sampleIdx) );
}
}
for( i = 0; i < listcount; i++ )
{
list[i].stump->release( (CvClassifier**) &(list[i].stump) );
cvReleaseMat( &(list[i].sampleIdx) );
}
cvFree( &intnode );
return (CvClassifier*) cart; /*返回創建的弱分類器*/
}