Straight-line fitting, quadratic curve fitting, broken-line fitting, and KNN nearest neighbours (with code)

This comes from an engineering application: a set of data has to be fitted with each of the four models above, and inputs are then evaluated against the fitted model; in short, function-curve fitting.

RevPre below defines the structures and methods, and Util is a usage example; the type argument of Opt selects which model to fit.
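For orientation, the typical call sequence is: create the model, hand it the data, fit, then predict and release. A minimal sketch, assuming x and y are double arrays of length n (the full example is Util.cpp below):

	Model *m = CreateModel();          // empty model, defaults to StraightLine
	SetOptData( m, x, y, n );          // copy the n (x, y) pairs into the model
	Opt( m, StraightLine );            // fit the chosen model type
	double yhat = Predict( m, 1.5 );   // evaluate the fitted model at x = 1.5
	ReleaseModel( &m );                // free the copies and the model itself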


1) Straight-line fitting: y = kx + b

2) Quadratic curve fitting: y = Ax^2 + Bx + C

These two are standard least-squares fits and need little explanation; the normal equations they solve are given below.
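Writing the model as y = a_0 + a_1·x + … + a_{m-1}·x^{m-1} (m = 2 for the line, m = 3 for the quadratic), the coefficients are the solution of the normal equations that CalculateParams assembles and DirectLU solves:

$$\sum_{j=0}^{m-1}\Big(\sum_{k=1}^{n} x_k^{\,i+j}\Big)\,a_j \;=\; \sum_{k=1}^{n} y_k\,x_k^{\,i}, \qquad i = 0,\dots,m-1$$

where n is the number of data points.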

3) KNN

KNN's Opt stage does nothing except keep all of the data pairs; at prediction time the K nearest points are picked (K is roughly 10% of the sample size, clamped to [MIN_KNN_K, MAX_KNN_K]) and a distance-weighted average of their y values is returned, as in the formula below.
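The weighting used in Predict is plain inverse-distance averaging over those K neighbours:

$$\hat{y}(x) \;=\; \frac{\sum_{i=1}^{K} y_i/d_i}{\sum_{i=1}^{K} 1/d_i}, \qquad d_i = |x_i - x|$$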

4) Broken-line fitting:

KNN was introduced first because the broken-line model ends up looking similar: it also keeps a series of data pairs, and prediction works out which segment x falls on. The difference is that the stored points are pruned rather than kept in full.

The data pairs are first quick-sorted by x, and a straight-line fit decides whether the overall trend is ascending or descending. Points that violate that ordering are then removed one by one: whenever a "bump" or "dip" appears, the point farther from the fitted trend line is dropped, until every point is in order and a monotonic broken line remains.
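The "distance to the fitted line" used to decide which of the two offending points to drop is simply the vertical residual against the trend line y = kx + b, exactly as BrokenLineOpt computes it:

$$\mathrm{err}_i = \left| k\,x_i + b - y_i \right|$$

and the point with the larger residual is removed.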


Code:

RevPre.h

#pragma once

#include <iostream>
#include <math.h>
#include <string.h>		// memcpy / memset
#include <stdlib.h>		// exit

/************************Model type*********************************/
#define MAX_POL_DEPTH	3
#define MAX_KNN_K		10
#define MIN_KNN_K		1

#define MAX(x,y)		( (x) > (y) ? (x) : (y) )
#define MIN(x,y)		( (x) < (y) ? (x) : (y) )
enum modelType{
	StraightLine = 0, // default
	CurveAt2,
	BrokenLine,
	KNNModel
};
struct Model{
	enum modelType type;
	// Line parameters (polynomial coefficients, lowest order first)
	double			lineParam[MAX_POL_DEPTH];
	// Point model (internal copies of the training pairs)
	double			*px, *py;
	int				len;
};
Model* CreateModel();
void ReleaseModel( Model** _ptr );
bool SetOptData( Model* ptr, double *x, double *y, int len );
bool Opt( Model *ptr, modelType type );
double Predict( Model *ptr, double x );
/**************************Polynomial*******************************/
/* Internal */
void CalculatePower(double *powers, int ptNum, int maxDepth, double *x );		// store the powers of every x[i] in a 2-D array
void CalculateParams(double *powers, int ptNum, int maxDepth, 
					 double *params, double *y);								// build the coefficient matrix of the normal equations
void DirectLU( double *params, int ptNum, int maxDepth, double *x );			// LU decomposition with column pivoting
inline void swap(double &,double &);											// swap two values
/* External */
bool PolynomialOpt( Model *ptr );
/************************StraightLine********************************/
bool StraightLineOpt( Model *ptr );
/************************BrokenLine********************************/
/*Internal*/
int SingleSort( double *index, double *context, int start, int end );
void QuickSort( double *index, double *context, int start, int end );
int CheckSequence( double *context, int start, int end, bool upTrend );
/*External*/
bool BrokenLineOpt( Model *ptr );
/********************KNN(Lazy-learning)****************************/
bool KNNOpt( Model *ptr );

RevPre.cpp

#include "Revise.h"

Model* CreateModel(){
	Model *ptr = new Model;
	ptr->type = StraightLine;

	ptr->px = ptr->py = NULL;
	ptr->len = 0;
	memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );

	return ptr;
}

void ReleaseModel( Model** _ptr ){
	Model *ptr = *_ptr;
	if ( ptr->px ) delete[] ptr->px;
	if ( ptr->py ) delete[] ptr->py;
	delete ptr;

	*_ptr = NULL;
	return ;
}

bool SetOptData( Model *ptr, double *x, double *y, int len ){
	if ( !ptr || !x || !y || len <= 0 ) return false;

	// (Re)allocate the internal copies so repeated calls with a different length stay safe
	if ( ptr->px ) delete[] ptr->px;
	if ( ptr->py ) delete[] ptr->py;
	ptr->px = new double[len];
	ptr->py = new double[len];

	ptr->len = len;
	memcpy( ptr->px, x, sizeof(double)*len );
	memcpy( ptr->py, y, sizeof(double)*len );
	return true;
}

bool Opt( Model *ptr, modelType type ){
	if ( !ptr ) return false;
	switch( type )
	{
	case StraightLine:
		return	StraightLineOpt( ptr );
	case CurveAt2:
		return	PolynomialOpt( ptr );
	case BrokenLine:
		return	BrokenLineOpt( ptr );
	case KNNModel:
		return	KNNOpt( ptr );
	default:
		return	false;
	}
}

double Predict( Model *ptr, double x ){
	if ( !ptr ) exit (-1);
	switch( ptr->type )
	{
	case StraightLine:
		return ptr->lineParam[0] + ptr->lineParam[1]*x;
	case CurveAt2:
		return ptr->lineParam[0] + ptr->lineParam[1]*x + ptr->lineParam[2]*x*x;
	case BrokenLine:
		{
			if ( ptr->len < 3 ) exit(-2);
			int first = 0;
			if ( x <= ptr->px[0] )
			{
				double x0 = ptr->px[0], x1 = ptr->px[1];
				double y0 = ptr->py[0], y1 = ptr->py[1];
				return y0 - (x0-x)*(y1-y0)/(x1-x0);
			}
			else if ( x >= ptr->px[ptr->len-1] )
			{
				double x0 = ptr->px[ptr->len-2], y0 = ptr->py[ptr->len - 2];
				double x1 = ptr->px[ptr->len-1], y1 = ptr->py[ptr->len - 1];
				return y1 -(x-x1)*(y0-y1)/(x1-x0);
			}
			else
			{
				while ( ptr->px[first] < x ) { first ++ ;}
				first --;
				double deltay = ptr->py[first+1] - ptr->py[first];
				double deltax = ptr->px[first+1] - ptr->px[first];
				return ptr->py[first] + deltay*(x-ptr->px[first])/deltax;
			}
		}
	case KNNModel:
		{
			int K = MAX( MIN_KNN_K, MIN( int(ptr->len*0.1), MAX_KNN_K ) );
			// Prepare the initial K neighbours
			double *dist_team = new double[K];
			int		*idx_team = new int[K];
			int		farestIdt = -1;
			double	farestDist = 0;

			int id = 0;
			for ( ; id < K; id ++ )
			{
				idx_team[id] = id;
				dist_team[id] = fabs( ptr->px[id] - x );
				if ( farestDist <= dist_team[id] )
				{
					farestIdt = id;
					farestDist = dist_team[id];
				}
			}
			// Looking for the K nearest neighbours
			while ( id < ptr->len )
			{
				if ( fabs( ptr->px[id] - x ) < farestDist )
				{
					// Update the team
					idx_team[farestIdt] = id;
					dist_team[farestIdt] = fabs( ptr->px[id] - x );
					// Update the farthest record
					farestIdt = 0;
					farestDist = dist_team[0];
					for ( int searchIdt = 1; searchIdt < K; searchIdt ++ )
					{
						if ( dist_team[searchIdt] > farestDist )
						{
							farestDist = dist_team[searchIdt];
							farestIdt = searchIdt;
						}
					}
				}
				id ++;
			}

			// Calculate their contribution (inverse-distance weighting)
			double res = 0.0;
			double weightSum = 0.0;
			for ( int seachIdt = 0; seachIdt < K; seachIdt ++ )
			{
				if ( dist_team[seachIdt] == 0.0 )
				{	// exact hit: return the stored y directly, avoiding division by zero
					res = ptr->py[idx_team[seachIdt]];
					weightSum = 1.0;
					break;
				}
				weightSum += 1.0/dist_team[seachIdt];
				res += 1.0/dist_team[seachIdt]*ptr->py[idx_team[seachIdt]];
			}
			delete[] dist_team;
			delete[] idx_team;
			return res/weightSum;
		}
	default:
		exit(-2);
	}
}
/******************StraightLine & Polynomial (least squares)*********************/
bool StraightLineOpt( Model *ptr )
{
	if ( !ptr ) return false; 
	if ( !ptr->px || !ptr->py ) return false;

	int outLen = 2;
	int ptNum = ptr->len, maxDepth = outLen;

	double *powers = new double[maxDepth*ptNum];
	double *params = new double[maxDepth*(maxDepth+1)];

	CalculatePower( powers, ptNum, maxDepth, ptr->px );

	CalculateParams( powers, ptNum, maxDepth, params, ptr->py ); // build the normal-equation matrix

	DirectLU( params, ptNum, maxDepth, ptr->lineParam ); // solve it with column-pivoted LU
	ptr->type = StraightLine;

	std::cout<<"-------------------------"<<std::endl;
	std::cout<<"擬合函數的係數分別爲:\n";
	for( int i=0;i<maxDepth;i++)
		std::cout<<"a["<<i<<"]="<<ptr->lineParam[i]<<std::endl;
	std::cout<<"-------------------------"<<std::endl;

	delete[] powers;
	delete[] params;

	return true;
}
bool PolynomialOpt( Model *ptr )
{
	if ( !ptr ) return false; 
	if ( !ptr->px || !ptr->py ) return false;

	int outLen = MAX_POL_DEPTH;
	int ptNum = ptr->len, maxDepth = outLen;

	double *powers = new double[maxDepth*ptNum];
	double *params = new double[maxDepth*(maxDepth+1)];

	CalculatePower( powers, ptNum, maxDepth, ptr->px );

	CalculateParams( powers, ptNum, maxDepth, params, ptr->py ); // build the normal-equation matrix

	DirectLU( params, ptNum, maxDepth, ptr->lineParam ); // solve it with column-pivoted LU
	ptr->type = CurveAt2;

	/*std::cout<<"-------------------------"<<std::endl;
	std::cout<<"擬合函數的係數分別爲:\n";
	for( int i=0;i<maxDepth;i++)
		std::cout<<"a["<<i<<"]="<<ptr->lineParam[i]<<std::endl;
	std::cout<<"-------------------------"<<std::endl;*/

	delete[] powers;
	delete[] params;

	return true;
}

void CalculatePower(double *powers, int ptNum, int maxDepth, double *x )
{
	if ( !powers || !x ) return ;

	int			i, j, k;
	double		temp;

	for( i = 0; i < maxDepth; i ++ )
		for( j = 0; j < ptNum; j ++ )
		{
			temp = 1;
			for( k = 0; k < i; k ++ )
				temp *= x[j];
			powers[i*ptNum+j] = temp;
		}
	return ;
}

void CalculateParams(double *powers, int ptNum, int maxDepth, 
					 double *params, double *y)
{
	if ( !powers || !params || !y ) return ;

	int			i, j, k;
	double		temp;
	int			step = maxDepth + 1;

	for( i = 0; i < maxDepth; i ++ )
	{
		for(j = 0; j < maxDepth; j ++ )
		{
			temp = 0;
			for( k = 0; k < ptNum; k ++ )
				temp += powers[i*ptNum+k]*powers[j*ptNum+k];
			params[i*step+j] = temp;
		}

		temp = 0;
		for( k = 0; k < ptNum; k ++ )
			temp += y[k]*powers[i*ptNum+k];
		params[i*step+maxDepth] = temp;	// right-hand side of the normal equations
	}

	return ;
}

inline void swap(double &a,double &b)
{
	double t = a;
	a = b;
	b = t;
}

void DirectLU( double *params, int ptNum, int maxDepth, double *x )
{
	int				i, r, k, j;
	double			max;
	int				step = maxDepth + 1;

	double *s = new double[maxDepth];
	double *t = new double[maxDepth];
	// choose the main element
	for( r = 0; r < maxDepth; r ++ )
	{
		max = 0;
		j = r;
		for( i = r; i < maxDepth; i ++ ) 
		{
			s[i] = params[i*step+r];
			for( k = 0; k < r; k ++ )
				s[i] -= params[i*step+k] * params[k*step+r];
			s[i] = fabs(s[i]);

			if( s[i] > max ){
				j = i;
				max = s[i];
			}
		}
		// if the "main"element is not @ row r, swap the corresponding element 
		if( j != r ) 
		{
			for( i = 0; i < maxDepth + 1; i ++ )
				swap( params[r*step+i], params[j*step+i] );
		}
		for( i = r; i < step; i ++ ) 
			for( k = 0; k < r; k ++ ){
				params[r*step+i] -= params[r*step+k] * params[k*step+i];
			}
			for(i = r+1; i < maxDepth; i ++ ) 
			{
				for ( k = 0; k < r; k ++ )
					params[i*step+r] -= params[i*step+k] * params[k*step+r];
				params[i*step+r] /= params[r*step+r];
			}
	}
	for( i = 0; i < maxDepth; i ++ )
		t[i] = params[ i*step + maxDepth ];
	for ( i = maxDepth - 1; i >= 0; i -- ) // back substitution for the final solution
	{
		for ( r = maxDepth - 1; r > i; r -- )
			t[i] -= params[ i*step + r ] * x[r];
		x[i] = t[i]/params[i*step+i];
	}

	delete[] s;
	delete[] t;

	return ;
}

/**********************Broken Line***************************/
// Quick Sort
int SingleSort( double *index, double *context, int start, int end )
{
	if ( end - start < 1 ) return start;

	// partition around index[start]; context[] is carried along as paired data
	int i = start, j = end;
	double key = index[i];
	double key_ = context[i];
	while ( i < j )
	{
		while ( i < j && index[j] >= key ) j --;
		index[i] = index[j];
		context[i] = context[j];
		while ( i < j && index[i] <= key ) i ++;
		index[j] = index[i];
		context[j] = context[i];
	}

	index[i] = key;
	context[i] = key_;

	return i;
}

void QuickSort( double *index, double *context, int start, int end )
{
	if ( end - start < 1 ) return ; // important
	int mid = SingleSort( index, context, start, end );
	QuickSort( index, context, start, mid - 1 );
	QuickSort( index, context, mid+ 1, end );
}
int CheckSequence( double *context, int start, int end, bool upTrend )
{
	int i = start;
	for ( ; i < end; i ++ )
	{
		if ( upTrend && context[i+1] < context[i] )
		{
			return i;
		}
		if ( !upTrend && context[i] < context[i+1] )
		{
			return i;
		}
	}
	return -1;
}
// Form the broken line
bool BrokenLineOpt( Model *ptr )
{
	if ( !ptr ) return false;
	if ( !ptr->len || !ptr->px || !ptr->py ) return false;

	// analyse the trend of points and get its approximate line
	StraightLineOpt( ptr );
	double k = ptr->lineParam[1], b = ptr->lineParam[0];
	bool upTrend = ( k > 0 );
	// sort the pairs by px (py is carried along as the paired value)
	QuickSort( ptr->px, ptr->py, 0, ptr->len - 1 );

	int oddPoint = 0;
	while ( (oddPoint = CheckSequence( ptr->py, oddPoint, ptr->len -1, upTrend ) ) != -1 )
	{
		double formerErr = fabs( k*ptr->px[oddPoint] + b - ptr->py[oddPoint] );
		double laterErr = fabs( k*ptr->px[oddPoint+1] + b - ptr->py[oddPoint+1] );
		oddPoint = formerErr > laterErr ? oddPoint : oddPoint + 1;
		// remove the odd point: shift the whole tail left by one (regions overlap, so memmove)
		memmove( ptr->py + oddPoint, ptr->py + oddPoint + 1, sizeof(double)*(ptr->len - oddPoint - 1) );
		memmove( ptr->px + oddPoint, ptr->px + oddPoint + 1, sizeof(double)*(ptr->len - oddPoint - 1) );
		ptr->len --;
		oddPoint = MAX( oddPoint - 1, 0 );	// step back so the new neighbouring pair is re-checked
	}
	memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );
	ptr->type = BrokenLine;
	return true;
}

/**********************Lazy Learning***************************/
bool KNNOpt( Model *ptr )
{
	if ( !ptr ) return false;
	// Nothing to fit: KNN is lazy learning, so the real work happens in Predict()
	memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );
	ptr->type = KNNModel;
	return true;
}

Util.cpp

#include "Revise.h"

int main()
{
	double x[10], y[10];
	for ( int i = 0; i < 10; i ++ )
	{
		x[i] = 10 - i;
		y[i] = (i-2)*(i-2);
	}

	Model *model = CreateModel();
	SetOptData( model, x, y, 10 );
	Opt( model, BrokenLine );
	double result = Predict( model, 1.5 );
	std::cout << "Prediction at x = 1.5: " << result << std::endl;
	ReleaseModel( &model );

	return 0;
}
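
To try one of the other models, only the second argument of Opt changes. For example, a quadratic fit of the same data (a minimal sketch, replacing the Opt( model, BrokenLine ) call above):

	Opt( model, CurveAt2 );                          // least-squares quadratic instead of the broken line
	std::cout << Predict( model, 1.5 ) << std::endl; // evaluate the quadratic at x = 1.5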

