原題地址
題意:給一個數m和一個空數組a,每次可以選1~m中任意一個數加入到a中,求最後a數組中所有數的gcd爲1時的期望長度;
令dp[i] 表示當前數組的gcd爲i時,還期望添加dp[i]個數使得所有數的gcd爲1;那麼初始條件dp[1] = 0;
狀態轉移公式:
解釋一下,1表示添一個數;求和表示從1~m中選一個數加入數組,期望是dp[gcd(i,j)],選到的概率是1/m。直接做複雜度O(m^2),過不了,所以要優化,觀察這個式子,gcd(i,j)肯定會出現重複項,比如,當i=4時,j=6和j=8都會讓gcd(i,j)=2。從這裏入手,假設gcd(i,j)的結果是a1,a2…ax共x項,每一項出現的次數是c1,c2…cx;那麼原式就是:
並且a[1~x]全部是i的因子。由於dp[i]未知,所以存在a等於i,這個式子還要化解;
此時a1~ax不存在值爲i的數;整理得:
求取val:對於枚舉i的因子,對於當前枚舉到的因子j(注意跳過j==1)
求它的c值,即求[1~m]中有多少個數k滿足gcd(k,i)=j,再化簡就是[1 ~ m/j]中有多少個數k和i/j互質。這樣枚舉i的因子,求取的複雜度變成了O(logn),總共O(n*logn)。
#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<map>
#include<bitset>
#include<queue>
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define lowbit(x) ((x)&(-x))
const int maxn = 1e5+10;
const int maxm = 200 + 10;
const int inf_max = 0x3f3f3f;
const int mod = 1000000007;
const double eps = 1e-4;
using namespace std;
typedef long long ll;
inline int read() {
int x = 0,f = 1;
char ch = getchar();
while(ch<'0' || ch>'9') {if(ch == '-') f = -1;ch = getchar();}
while(ch>='0' && ch<='9') x = x*10 + ch -'0',ch = getchar();
return x*f;
}
int m,cnt[maxn],ct,sum,dp[maxn];
vector<int>fac[maxn];
int fast_power(int base,int power) {
int ret = 1;
while(power) {
if(power&1) ret = 1ll*ret*base % mod;
base = 1ll*base*base % mod;
power >>= 1;
}
return ret;
}
void get_fac(int x) {
int tt = x;
for(int i = 2;i*i <= x; ++i) {
if(x%i == 0) {
fac[tt].push_back(i);
while(x%i == 0) x /= i;
}
}
if(x > 1) fac[tt].push_back(x);
}
void dfs(int pos,int use,int cur,int up,int now) {
if(pos == fac[now].size()) {
if(use&1) sum -= up/cur;
else sum += up/cur;
return ;
}
dfs(pos+1,use,cur,up,now);
dfs(pos+1,use+1,cur*fac[now][pos],up,now);
}
int main()
{
for(int i = 1;i <= 100000; ++i) get_fac(i);
m = read();
dp[1] = 0;
for(int i = 2; i <= m; ++i) {
int tot = m/i,al = 0;
for(int j = 1;j*j <= i ;++j) {
if(i%j) continue;
sum = 0;
dfs(0,0,1,m/j,i/j);
//printf("1:%d\n",sum);
al = (al+1ll*sum*dp[j]%mod)%mod;
if(i/j != j && j != 1) {
sum = 0;
dfs(0,0,1,m/(i/j),j);
al = (al+1ll*sum*dp[i/j]%mod)%mod;
// printf("2:%d\n",sum);
}
}
dp[i] = 1ll*(al+m)%mod*fast_power(m-tot,mod-2)%mod;
}
int ans = 0;
for(int i = 1;i <= m; ++i) ans = (ans + dp[i])%mod;
ans = (1ll*ans*fast_power(m,mod-2)%mod+1)%mod;
cout<<ans<<endl;
return 0;
}