圖像卷積的fft實現驗證(python)

1、Caffe的卷積操作時間主要在矩陣乘法,假設一個m*n卷積核,且輸入通道數爲1,輸出特徵圖大小爲h*w,則乘法個數m*n*h*w,這裏的優化僅限於對矩陣的乘法優化,因此,只要選擇適合的矩陣計算庫就可以了。

2、若使用FFT來計算圖像卷積。其主要步驟如下。

假設輸入圖像的大小爲len=h*w,卷積核大小k_len=m*n;通常len>>k_len;

  1. 對輸入圖像A做FFT,其算法的時間複雜度爲o(lenlglen);

  2. 對卷積核B做FFT,但是由於卷積核與輸入圖像尺寸不一樣,需要將卷積核擴展,即將卷積核倒置後,補len-k_len個0。

  3. 將A、B傅里葉變換的結果相乘,即對應位相乘獲得結果C。乘法個數爲len,時間複雜度爲o(len)

  4. 對C做IFFT,得到結果D,在D中每隔k_len的值實部取出來,就是圖像卷積的結果。因爲圖像卷積其實就是對應位相乘,所以需要每隔k_len取值,時間複雜度爲o(len)

假設卷積核個數>1,需要對卷積核重新擴展後重復2)3)4)步驟,並與上一個卷積核圖像卷積的值對應位相加就能獲得。

驗證正確性:

輸入圖像的卷積順序-1,1,3,8 2,43,1,3 1,3,1,3 54,-2,-3,-1

卷積核1,2,-1,3

圖像卷積結果:22,96,15,50.

此處注意,計算結果是用卷積核與輸入圖像進行乘法累加。

使用FFT結果:22,96,15,50。

此處注意,將卷積核逆置。

結果正確

 

 

暫時未能看出此方法的優越性,從空間上看,需要對卷積核進行擴展,其空間大小與輸入圖像的尺寸大小一樣。時間上分析,僅二者FFT對應相乘的時間的乘法個數就和矩陣乘法個數的數量級是一樣的了(當len>>k_len時)。適用於卷積核尺寸較大的情景。

圖2出處:https://core.ac.uk/download/pdf/24989291.pdf

雖然沒理解爲什麼性能提升這麼多,但是該論文所做的實驗證明了FFT性能很好,雖然複雜度推導跟我推導的差不多~此論文用了cuFFT庫

 

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

至於多通道,多卷積核的,經python驗證,卷積結果正確,其大概步驟如下:

水平有限,出錯之處,希望得到各位的指正。

--------------------------------找到當時寫的fft圖像卷積,忘了對不對了---------------------------------------------------------------------------

template <typename Dtype>
void BaseConvolutionLayer<Dtype>::forward_cpu_gemm(const Dtype* input,
    const Dtype* weights, Dtype* output, bool skip_im2col) {
    const Dtype* col_buff = input;
    //get an array from a channel imag
    //get an array from a kernel,kernel_h*kernel_w=N;
    //make fft and ifft,res=ifft(),res[N,2N,3N....]is the channel conv result 
    //get another channel imag,and repeat
    //output=sum(conv result)
  const Dtype * data_im = input;
  const int channels = conv_in_channels_;
  const int height = conv_input_shape_.cpu_data()[1];
  const int width = conv_input_shape_.cpu_data()[2];
  const int kernel_h = kernel_shape_.cpu_data()[0];
  const int kernel_w = kernel_shape_.cpu_data()[1];
  const int pad_h = pad_.cpu_data()[0];
  const int pad_w = pad_.cpu_data()[1];
  const int stride_h = stride_.cpu_data()[0];
  const int stride_w = stride_.cpu_data()[1];
  const int dilation_h = dilation_.cpu_data()[0];
  const int dilation_w = dilation_.cpu_data()[1];
  const int output_h = (height + 2 * pad_h -
    (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int output_w = (width + 2 * pad_w -
    (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
  const int channel_size = height * width;
  const int kernel_size = kernel_h * kernel_w;
  const int out_size = output_h * output_w;
  

    //only for test
   /*int channels=1;
   int channel_size=9;
   int out_size=4;
   int kernel_size=4;
   Dtype my_test[]={2,3,1,7,6,3,2,1,9};
   Dtype *my_test_p=my_test;
   Dtype my_kernel[]={4,2,1,0};
   int kernel_h=2;
   int kernel_w=2;
   int imag_w=3;*/
   int sum_k=0;
   int count_i=0;
    //make input
    //    std::cout<<"data_in_fft"<<*(input+channel_size)<<std::endl;
    memset(output,0,out_size);
    //std::cout<<"channels="<<channels<<std::endl;
    //for(int k_l = 0;k_l < conv_out_channels;k_l++;)
    //{
        for(int c = 0;c < channels ; c++)
        {
          //memset(data_in_fft,0,channel_size);
          //memset(data_kernel,0,channel_size);
          Dtype data_in_fft[out_size*kernel_size] = {0};
          Dtype data_kernel[out_size*kernel_size] = {0};
        for(int count=0;count<out_size;count++)
        {   //memcpy(data_in_fft+count*kernel_size,input+(channel_size*c)+count,kernel_size*sizeof(Dtype));
          if(count_i==width-kernel_w+1)
                        {
                            count_i=0;
                          data_im+=kernel_w-1;
                        }
          count_i++;
          for(int kk=0;kk<kernel_h;kk++)
                    {
            
                      memcpy(data_in_fft+sum_k,data_im+width*kk,sizeof(Dtype)*kernel_w);
            sum_k+=kernel_w;
            /*for(int jj=0;jj<out_size*kernel_size;jj++)
                       std::cout<<count+kk<<"th"<<*(data_in_fft+jj)<<std::endl;
            std::cout<<"--------------------------"<<std::endl;*/
          }
          data_im++;
        }
      /*for(int jj=0;jj<out_size*kernel_size;jj++)
                std::cout<<*(data_in_fft+jj)<<std::endl;
      std::cout<<"--------------------------"<<std::endl;*/
            cv::Mat a(1, out_size*kernel_size, CV_32F, data_in_fft);
            //memcpy(data_kernel,this->blobs_[0]->mutable_cpu_data()+(kernel_size*c),kernel_size*sizeof(Dtype));
      memcpy(data_kernel,this->blobs_[0]->mutable_cpu_data()+(kernel_size*c),kernel_size*sizeof(Dtype));
      //std::cout<<"data_kernel"<<*data_kernel<<std::endl;

            const Dtype* weight = this->blobs_[0]->cpu_data();

            for(int j = 0;j < (kernel_size>>1);j++)
            {
                int tmp;
                tmp=data_kernel[j];
                data_kernel[j]=data_kernel[kernel_size-j-1];
                data_kernel[kernel_size-j-1]=tmp;

            }   
            /*for(int jj=0;jj<out_size*kernel_size;jj++)
            std::cout<<*(data_kernel+jj)<<std::endl;
            std::cout<<"--------------------------"<<std::endl; */                                                             
            cv::Mat b(1,out_size*kernel_size,CV_32F,data_kernel);
            cv::Mat padded(a);
           
            cv::Mat planes[] = {cv::Mat_<float>(padded), cv::Mat::zeros(padded.size(), CV_32F)};
            cv::Mat complexImg;
            cv::merge(planes, 2, complexImg);
            cv::dft(complexImg,complexImg);
            cv::split(complexImg, planes);

            cv::Mat padded1(b);
            cv::Mat planes2[] = {cv::Mat_<float>(padded1), cv::Mat::zeros(padded1.size(), CV_32F)};
            cv::Mat complexImg2;
            cv::merge(planes2, 2, complexImg2);
            cv::dft(complexImg2,complexImg2);
            cv::split(complexImg2, planes2);
    

            cv::Mat planes3[] = {cv::Mat::zeros(padded1.size(),CV_32F), cv::Mat::zeros(padded1.size(), CV_32F)};
            planes3[0]=planes[0].mul(planes2[0])-planes[1].mul(planes2[1]);
            planes3[1]=planes[0].mul(planes2[1])+planes[1].mul(planes2[0]);

            cv::merge(planes3, 2, complexImg);
            cv::idft(complexImg,complexImg,CV_DXT_INV_SCALE);
            cv::split(complexImg, planes);
            //std::cout<<planes[0]<<std::endl;
            //std::cout<<"+++++++++++++++++++++++++++++++"<<std::endl;
            //std::cout<<planes[0]<<std::endl;
            //std::cout<<"*******************************"<<std::endl;
            //std::cout<<"output"<<std::endl;
            for(int m = kernel_size-1,i=0; m < out_size*kernel_size ;m += kernel_size)
            {
                 *(output +i)=planes[0].at<float>(0,m);
                 i++;
                //std::cout<<*(output+i)<<std::endl;
            }
            //for(int iii=0;iii<4;iii++)
            //std::cout<<*(output+iii)<<std::endl;
    
        }
    //}
    
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章