基於OpenCV CxCore和Conjugate Gradient Method求函數局部極小值的抽象類

這是Matlab代碼的一個簡單翻譯,原作者有關信息參見:
http://www.kyb.tuebingen.mpg.de/bs/people/carl/code/minimize/

mlminimize.h
  1. #ifndef GUARD_mlminimize_h
  2. #define GUARD_mlminimize_h
  3. #include <ml.h>
  4. struct CV_EXPORTS CvMinimizeParams
  5. {
  6.     double INT;
  7.     double EXT;
  8.     int ITMAX;
  9.     double RATIO;
  10.     double SIG;
  11.     double RHO;
  12.     CvMinimizeParams()
  13.     : INT(.1),
  14.       EXT(3.),
  15.       ITMAX(20),
  16.       RATIO(10.),
  17.       SIG(.1),
  18.       RHO(.05)
  19.     {}
  20.     CvMinimizeParams( double _INT, double _EXT, int _ITMAX, double _RATIO, double _SIG, double _RHO )
  21.     : INT(_INT),
  22.       EXT(_EXT),
  23.       ITMAX(_ITMAX),
  24.       RATIO(_RATIO),
  25.       SIG(_SIG),
  26.       RHO(_RHO)
  27.     {}
  28. };
  29. class CV_EXPORTS CvMinimize
  30. {
  31.     private:
  32.         CvMinimizeParams* params;
  33.         virtual bool function( CvMat* x, double& result ) = 0;
  34.         virtual bool derivative( CvMat* x, CvMat* result ) = 0;
  35.     public:
  36.         CvMinimize( CvMinimizeParams* _params )
  37.         : params(_params)
  38.         {}
  39.         CvMat* minimize( CvMat* x, int length, double red = 1. );
  40. };
  41. #endif
mlminimize.cpp
  1. #include "mlminimize.h"
  2. CvMat* CvMinimize::minimize( CvMat* x,
  3.                  int length,
  4.                  double red )
  5. {
  6.     CvMat *df0, *df3, *dF0;
  7.     CvMat *s;
  8.     CvMat *X, *X0, *Xn;
  9.     X = cvCreateMat( x->rows, x->cols, CV_64FC1 );
  10.     cvCopy( x, X );
  11.     dF0 = cvCreateMat( X->rows, X->cols, CV_64FC1 );
  12.     df0 = cvCreateMat( X->rows, X->cols, CV_64FC1 );
  13.     df3 = cvCreateMat( X->rows, X->cols, CV_64FC1 );
  14.     s = cvCreateMat( X->rows, X->cols, CV_64FC1 );
  15.     X0 = cvCreateMat( X->rows, X->cols, CV_64FC1 );
  16.     Xn = cvCreateMat( X->rows, X->cols, CV_64FC1 );
  17.     cvZero( dF0 );
  18.     cvZero( df0 );
  19.     cvZero( df3 );
  20.     cvZero( s );
  21.     cvZero( X0 );
  22.     cvZero( Xn );
  23.     double F0 = 0, f0 = 0, f1 = 0, f2 = 0, f3 = 0, f4 = 0;
  24.     double x1 = 0, x2 = 0, x3 = 0, x4 = 0;
  25.     double d0 = 0, d1 = 0, d2 = 0, d3 = 0, d4 = 0;
  26.     double A = 0, B = 0;
  27.     bool ls_failed = 0;
  28.     function( X, f0 );
  29.     derivative( X, df0 );
  30.     cvSubRS( df0, cvScalar(0), s );
  31.     d0 = -cvDotProduct( s, s );
  32.     x3 = red/(1.-d0);
  33.     int i = 0;
  34.     int l = ( length > 0 ) ? length : -length;
  35.     int ls = ( length > 0 ) ? 1 : 0;
  36.     int eh = ( length > 0 ) ? 0 : 1;
  37.     while ( i < l )
  38.     {
  39.         i+=ls;
  40.         cvCopy( X, X0 );
  41.         F0 = f0;
  42.         cvCopy( df0, dF0 );
  43.         int m = ( length > 0 ) ? params->ITMAX : l-i;
  44.         if ( params->ITMAX < m )
  45.             m = params->ITMAX;
  46.         for ( ; ; )
  47.         {
  48.             x2 = 0;
  49.             f2 = f0;
  50.             d2 = d0;
  51.             f3 = f0;
  52.             cvCopy( df0, df3 );
  53.             while ( m > 0 )
  54.             {
  55.                 m--;
  56.                 i+=eh;
  57.                 cvScaleAdd( s, cvScalar(x3), X, Xn );
  58.                 cvZero( df3 );
  59.                 if ((function( Xn, f3 ))&&(derivative( Xn, df3 )))
  60.                     break;
  61.                 else
  62.                     x3 = (x2+x3)*.5;
  63.             }
  64.             if ( f3 < F0 )
  65.             {
  66.                 cvCopy( Xn, X0 );
  67.                 F0 = f3;
  68.                 cvCopy( df3, dF0 );
  69.             }
  70.             if ( (d3 > params->SIG*d0)||(f3 > f0+x3*params->RHO*d0)||(m <= 0) )
  71.                 break;
  72.             x1 = x2;
  73.             f1 = f2;
  74.             d1 = d2;
  75.             x2 = x3;
  76.             f2 = f3;
  77.             d2 = d3;
  78.             A = 6.*(f1-f2)+3.*(d2+d1)*(x2-x1);
  79.             B = 3.*(f2-f1)-(2.*d1+d2)*(x2-x1);
  80.             x3 = B*B-A*d1*(x2-x1);
  81.             if ( x3 < 0 )
  82.                 x3 = x2*params->EXT;
  83.             else {
  84.                 x3 = x1-d1*(x2-x1)*(x2-x1)/(B+sqrt(x3));
  85.                 if ( x3 < 0 )
  86.                     x3 = x2*params->EXT;
  87.                 else {
  88.                     if ( x3 > x2*params->EXT )
  89.                         x3 = x2*params->EXT;
  90.                     else if ( x3 < x2+params->INT*(x2-x1) )
  91.                         x3 = x2+params->INT*(x2-x1);
  92.                 }
  93.             }
  94.         }
  95.         while ( ((fabs(d3) > -params->SIG*d0)||(f3 > f0+x3*params->RHO*d0 ))&&(m > 0) )
  96.         {
  97.             if ( (d3 > 1e-8)||(f3 > f0+x3*params->RHO*d0) )
  98.             {
  99.                 x4 = x3;
  100.                 f4 = f3;
  101.                 d4 = d3;
  102.             } else {
  103.                 x2 = x3;
  104.                 f2 = f3;
  105.                 d2 = d3;
  106.             }
  107.             if ( f4 > f0 )
  108.                 x3 = x2-(.5*d2*(x4-x2)*(x4-x2))/(f4-f2-d2*(x4-x2));
  109.             else {
  110.                 A = 6.*(f2-f4)/(x4-x2)+3.*(d4+d2);
  111.                 B = 3.*(f4-f2)-(2.*d2+d4)*(x4-x2);
  112.                 x3 = B*B-A*d2*(x4-x2)*(x4-x2);
  113.                 if ( x3 < 0 )
  114.                     x3 = (x2+x4)*.5;
  115.                 else
  116.                     x3 = x2+(sqrt(x3)-B)/A;
  117.             }
  118.             if ( x3 > x4-params->INT*(x4-x2) )
  119.                 x3 = x4-params->INT*(x4-x2);
  120.             if ( x3 < x2+params->INT*(x4-x2) )
  121.                 x3 = x2+params->INT*(x4-x2);
  122.             cvScaleAdd( s, cvScalar(x3), X, Xn );
  123.             function( Xn, f3 );
  124.             cvZero( df3 );
  125.             derivative( Xn, df3 );
  126.             if ( f3 < F0 )
  127.             {
  128.                 cvCopy( Xn, X0 );
  129.                 F0 = f3;
  130.                 cvCopy( df3, dF0 );
  131.             }
  132.             m--;
  133.             i+=eh;
  134.             d3 = cvDotProduct( df3, s );
  135.         }
  136.         if ( (fabs(d3) < -params->SIG*d0)&&(f3 < f0+x3*params->RHO*d0) )
  137.         {
  138.             cvCopy( Xn, X );
  139.             f0 = f3;
  140.             cvScaleAdd( s, cvScalar((cvDotProduct( df0, df3 )-cvDotProduct( df3, df3 ))/cvDotProduct( df0, df0 )), df3, s );
  141.             cvSubRS( s, cvScalar(0), s );
  142.             cvCopy( df3, df0 );
  143.             d3 = d0;
  144.             d0 = cvDotProduct( df0, s );
  145.             if ( d0 > 0 )
  146.             {
  147.                 cvSubRS( df0, cvScalar(0), s );
  148.                 d0 = -cvDotProduct( s, s );
  149.             }
  150.             x3 = x3*(params->RATIO < d3/(d0-1e-8) ? params->RATIO:d3/(d0-1e-8));
  151.             ls_failed = 0;
  152.         } else {
  153.             cvCopy( X0, X );
  154.             f0 = F0;
  155.             cvCopy( dF0, df0 );
  156.             if ( ls_failed )
  157.                 break;
  158.             cvSubRS( df0, cvScalar(0), s );
  159.             d0 = -cvDotProduct( s, s );
  160.             x3 = red/(1.-d0);
  161.             ls_failed = 1;
  162.         }
  163.     }
  164.     cvReleaseMat( &s );
  165.     cvReleaseMat( &X0 );
  166.     cvReleaseMat( &Xn );
  167.     cvReleaseMat( &dF0 );
  168.     cvReleaseMat( &df0 );
  169.     cvReleaseMat( &df3 );
  170.     return X;
  171. }
使用樣例
  1. #include "mlminimize.h"
  2. #include <iostream>
  3. class CV_EXPORTS Rosenbrock : public CvMinimize
  4. {
  5.     private:
  6.         virtual bool function( CvMat* x, double& result )
  7.         {
  8.             double* x_vec = x->data.db;
  9.             result = 0;
  10.             for ( int i = 0; i < x->cols-1; i++ )
  11.                 result += 100*(x_vec[i+1]-x_vec[i]*x_vec[i])*(x_vec[i+1]-x_vec[i]*x_vec[i])+(1-x_vec[i])*(1-x_vec[i]);
  12.             return 1;
  13.         }
  14.         virtual bool derivative( CvMat* x, CvMat* result )
  15.         {
  16.             double* x_vec = x->data.db;
  17.             double* result_vec = result->data.db;
  18.             for ( int i = 0; i < x->cols-1; i++ )
  19.                 result_vec[i] = -400*x_vec[i]*(x_vec[i+1]-x_vec[i]*x_vec[i])-2*(1-x_vec[i]);
  20.             for ( int i = 1; i < x->cols; i++ )
  21.                 result_vec[i] += 200*(x_vec[i]-x_vec[i-1]*x_vec[i-1]);
  22.             return 1;
  23.         }
  24.     public:
  25.         Rosenbrock( CvMinimizeParams* params )
  26.         : CvMinimize(params) {}
  27.         bool derivative1( CvMat* x, CvMat* result )
  28.         {
  29.             double* x_vec = x->data.db;
  30.             double* result_vec = result->data.db;
  31.             for ( int i = 0; i < x->cols-1; i++ )
  32.                 result_vec[i] = -400*x_vec[i]*(x_vec[i+1]-x_vec[i]*x_vec[i])-2*(1-x_vec[i]);
  33.             for ( int i = 1; i < x->cols; i++ )
  34.                 result_vec[i] += 200*(x_vec[i]-x_vec[i-1]*x_vec[i-1]);
  35.             return 1;
  36.         }
  37. };
  38. int main()
  39. {
  40.     CvMinimizeParams* params = new CvMinimizeParams();
  41.     Rosenbrock* rb = new Rosenbrock( params );
  42.     CvMat* X = cvCreateMat( 1, 2, CV_64FC1 );
  43.     cvZero( X );
  44.     CvMat* result = rb->minimize( X, 25 );
  45.     for ( int k = 0; k < 2; k++ )
  46.         printf("%f/n", result->data.db[k]);
  47. }
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章