關於母函數網上有一大堆解析,不過解析具體代碼的好像不多,這裏就簡要介紹一下好了。僅供初學者參考,老鳥請路過-.-
這篇博客雖然寫的是hdu 1398,不過我們還是從hdu 1028開始比較好,因爲這道題是最赤裸裸的母函數模板。
母函數說白了就是計算幾個式子的乘積,如
(1 + x + x^1 + x^2)(1 + x^2) = 1 + x + 2x^2 + x^3 + x^4
通過母函數我們要得出的就是右式。我們設兩個變量c1和c2,c2是臨時數組,先不管它,我們將c1的下標設爲指數,值設爲這個指數的係數,如5x^3就是c1[3] = 5。
先看一看代碼
#include<iostream>
using namespace std;
#define maxn 125
int c1[maxn], c2[maxn];
int main()
{
int n;
while(~scanf("%d", &n))
{
for(int i = 0; i <= n; i++) // --- 第一步
{
c1[i] = 1;
c2[i] = 0;
}
for(int i = 2; i <= n; i++) // --- 第二步
{
for(int j = 0; j <= n; j++) // --- 第三步
for(int k = 0; k+j <= n;k+=i) // --- 第四步
c2[k+j] += c1[j]; // --- 第五步
for(int i = 0; i <= n; i++) // 第六步
{
c1[i] = c2[i];
c2[i] = 0;
}
}
printf("%d\n", c1[n]);
}
return 0;
}
hdu 1028要計算的是(1 + x + x^2 + x^3 + ...) * (1 + x ^2 + x^4 + ...) * (1 + x^3 + x^6 + ...)
與手動計算一樣,先是第一個括號乘以第二個括號,得到結果後再繼續與第三個括號相乘,如此下去。
那麼我們首先初始化第一個括號的參數,也就是上述代碼中的第一步。然後從第二個括號開始乘,那麼就是上述的第二步。接下來的j依次指向第一個括號中的每一個數。
第四步是關鍵,k指向的是第二個括號中的每一個數,說到這裏可能還是比較模糊,我們就來手動執行一下。我們假設hdu1028的n現在是2
(1 + x + x^2) * (1 + x^2)
首先j指向第一個括號中的第一個數,也就是1,然後依次與第二個括號中的數相乘,那麼k循環一遍後我們得到了1 + x^2,接着我們把它記錄到c2這個臨時數組中
所以是c2[j+k] += c1[j]。j+k指的是相乘後的指數,c1[j]指的是【第一個括號中,指數爲j的數的係數】。
爲什麼是這樣寫的呢?個人所見,可以從兩方面理解:一是因爲,第二個括號中的係數都是1,因爲個數的含義體現在了指數上,這個很好理解。二是因爲,這個表達式前兩個括號相加後,是合併到第一個括號中的,然後繼續與"第二個“括號相乘,這裏的"第二個"括號實際上是第三個括號。
例如這個表達式:(1 + 2x + 3x^2) * (1 + x^2),我們會得到如下值:
(1 + 2x + 3x^2) * (1 + x^2) = 1 + x^2 + 2x + 2x^3 + 3x^2 + 2x^4,然後是合併同指數的係數,得到1 + 2x + 4x^2 + 2x^3 + 3x^4
所以是c2[j+k] += c1[j]。從這裏可以看出,母函數就是經過不斷計算,將指數一層層疊加到第一個括號中。
現在c2數組存了更新後的"第一個括號的係數",所以要放回c1中,供下次相乘使用。
現在回過來看看hdu 1398,(1 + x + x^2 + x^3 + ...) * (1 + x^2 + x^4 + ...) * (1 + x^9 + x^18 + ...)
同樣的,初始化0~n都是1,然後從第二個括號開始計算,因爲第一個括號永遠都是n+1項,所以j是從0到n。
第i個括號中的指數增量是i*i,所以代碼就成型了。
#include<iostream>
using namespace std;
#define maxn 310
int c1[maxn], c2[maxn];
int main()
{
int n;
while(~scanf("%d", &n) && n)
{
memset(c1, 0, sizeof(c1));
memset(c2, 0, sizeof(c2));
for(int i = 0; i <= n; i++)
c1[i] = 1;
for(int i = 2; i*i <= n; i++)
{
for(int j = 0; j <= n; j++)
for(int k = 0; j + k <= n; k += i*i)
c2[j+k] += c1[j];
for(int i = 0; i <= n; i++)
{
c1[i] = c2[i];
c2[i] = 0;
}
}
printf("%d\n", c1[n]);
}
return 0;
}