CF1139D step to one(數論+概率dp)

原題地址
題意:給一個數m和一個空數組a,每次可以選1~m中任意一個數加入到a中,求最後a數組中所有數的gcd爲1時的期望長度;
令dp[i] 表示當前數組的gcd爲i時,還期望添加dp[i]個數使得所有數的gcd爲1;那麼初始條件dp[1] = 0;
狀態轉移公式:在這裏插入圖片描述
解釋一下,1表示添一個數;求和表示從1~m中選一個數加入數組,期望是dp[gcd(i,j)],選到的概率是1/m。直接做複雜度O(m^2),過不了,所以要優化,觀察這個式子,gcd(i,j)肯定會出現重複項,比如,當i=4時,j=6和j=8都會讓gcd(i,j)=2。從這裏入手,假設gcd(i,j)的結果是a1,a2…ax共x項,每一項出現的次數是c1,c2…cx;那麼原式就是:
在這裏插入圖片描述
並且a[1~x]全部是i的因子。由於dp[i]未知,所以存在a等於i,這個式子還要化解;
blog.csdnimg.cn/20200612213138589.png)
此時a1~ax不存在值爲i的數;整理得:
在這裏插入圖片描述
求取val:對於枚舉i的因子,對於當前枚舉到的因子j(注意跳過j==1)
求它的c值,即求[1~m]中有多少個數k滿足gcd(k,i)=j,再化簡就是[1 ~ m/j]中有多少個數k和i/j互質。這樣枚舉i的因子,求取的複雜度變成了O(logn),總共O(n*logn)。

#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<map>
#include<bitset>
#include<queue>
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define lowbit(x) ((x)&(-x))
const int maxn = 1e5+10;
const int maxm = 200 + 10;
const int inf_max = 0x3f3f3f;
const int mod = 1000000007;
const double eps = 1e-4;
using namespace std;
typedef long long ll;
inline int read() {
    int x = 0,f = 1;
    char ch = getchar();
    while(ch<'0' || ch>'9') {if(ch == '-') f = -1;ch = getchar();}
    while(ch>='0' && ch<='9') x = x*10 + ch -'0',ch = getchar();
    return x*f;
}
int m,cnt[maxn],ct,sum,dp[maxn];
vector<int>fac[maxn];
int fast_power(int base,int power) {
    int ret = 1;
    while(power) {
        if(power&1) ret = 1ll*ret*base % mod;
        base = 1ll*base*base % mod;
        power >>= 1;
    }
    return ret;
}
void get_fac(int x) {
    int tt = x;
    for(int i = 2;i*i <= x; ++i) {
        if(x%i == 0) {
            fac[tt].push_back(i);
            while(x%i == 0) x /= i;
        }
    }
    if(x > 1) fac[tt].push_back(x);
}
void dfs(int pos,int use,int cur,int up,int now) {
    if(pos == fac[now].size()) {
        if(use&1) sum -= up/cur;
        else sum += up/cur;
        return ;
    }
    dfs(pos+1,use,cur,up,now);
    dfs(pos+1,use+1,cur*fac[now][pos],up,now);
}
int main()
{
    for(int i = 1;i <= 100000; ++i) get_fac(i);
    m = read();
    dp[1] = 0;
    for(int i = 2; i <= m; ++i) {
        int tot = m/i,al = 0;
        for(int j = 1;j*j <= i ;++j) {
            if(i%j) continue;
            sum = 0;
            dfs(0,0,1,m/j,i/j);
            //printf("1:%d\n",sum);
            al = (al+1ll*sum*dp[j]%mod)%mod;  
            if(i/j != j && j != 1) {
                sum = 0;
                dfs(0,0,1,m/(i/j),j);
                al = (al+1ll*sum*dp[i/j]%mod)%mod;
               // printf("2:%d\n",sum);
            }
        }
        dp[i] = 1ll*(al+m)%mod*fast_power(m-tot,mod-2)%mod;
    }
    int ans = 0;
    for(int i = 1;i <= m; ++i) ans = (ans + dp[i])%mod;
    ans = (1ll*ans*fast_power(m,mod-2)%mod+1)%mod;
    cout<<ans<<endl;
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章