hihoCoder 1035 自駕旅行 樹形DP

題目鏈接:http://hihocoder.com/problemset/problem/1035

題目顯然是一個樹形DP,我們用dp[ i ][ j ]表示已經詢問了子樹i的所有關鍵節點,人車的一個狀態。其中

j==0:人去,不管人是否回來

j==1:人去,人一定要回來

j==2:人車都去,人車都要回來

j==3:人車都去,人一定要回來,車不管

j==4:人車都去,不管人車最後是否回來。

注意,這裏顯然有 dp[i][1]>=dp[i][0]  ,dp[i][1]>=dp[i][2]>=dp[i][3]>=dp[i][4]

我們先考慮比較容易計算的,我們記人走邊(fa,son)的花費爲w1(fa,son),車爲w2(fa,son),顯然

  • dp[fa][1] = ∑ dp[son][1] + 2*w1(fa,son)
  • dp[fa][2] = ∑ min( dp[son][1] +2* w1(fa,son) , dp[son][2] + 2*w2(fa,son) )
現在我們考慮複雜一點的情況,dp[fa][0]和dp[fa][3]

對dp[fa][0],其實就是有一個邊選擇爲 dp[son][0] + w1(fa,son) 而其他的都選擇 dp[son][1] + 2*w1(fa,son);或者都選後者,我們記

  • temp = min(dp[son][0] - dp[son][1] - w1(fa,son) ),temp = min(temp,0)
  • 則 dp[fa][0] = dp[fa][1] + temp;
現在解決dp[fa][3],爲了解釋的方便,我們記 t2 = min( dp[son][1] +2* w1(fa,son) , dp[son][2] + 2*w2(fa,son) )

dp[fa][3]的決策方式,其實就是一個選擇 w2(fa,son) + dp[son][3] + w1(fa,son) ,其他的選t2,或者都選t2,這樣我們得到dp[fa][3]的轉移方程

  • t3=0,t3=min( t3 , w1(fa,son)+ w2(fa,son)+ dp[son][3] - t2 );
  • dp[fa][3] = dp[fa][2] + t3;
dp[fa][4]則是一個比較複雜的過程,我們從走的最後一棵子樹考慮

1、走最後一棵子樹時,還有車

這樣的情況還是很好計算的,最後一棵子樹的選擇是 min( w1(fa,son) + dp[son][0] , w2(fa,son) + dp[son][4] ) 而其他選擇都是t2。或者全部選擇t2。

2、走最後一棵子樹時,沒有車了。

這種情況大概就是,最後一棵子樹的決策是w1(fa,son)+w2(fa,son)+dp[son][3];其他的,某一棵子樹,人車走下去,然後只有人回來了,剩下的決策是t2。

這個情況是最麻煩的,因爲有兩棵特殊的子樹,我們類似上面的差值統計的時候,要避免兩個差值代表同一個子樹。爲了解決這個問題,我最開始的方法是,先後將兩個差值的優先級設爲不同,這樣統計的話,可以避免重複子樹。但是因爲順序處理沒處理好,我的代碼最開始有BUG。但是因爲能A這題,所以我沒注意到。

這裏多謝swwlqw的指點,這個問題現在已經改正。

我們記

  • ff1 = w1(fa,son)+ w2(fa,son)+ dp[son][3] - t2;
  • ff2 = w1(fa,son)+ dp[son][0] - t2;
我們的目標是找到不同son,使得ff1 + ff2 最小。我們可以先去找ff1的最小和次小,並記錄達到最小ff1的兒子match。再遍歷一下所有兒子,如果此時的son就是match,那麼就配合ff1的次小值,否則配合最小值。

//#pragma comment(linker, "/STACK:102400000,102400000")
#include<cstdio>
#include<cstring>
#include<vector>
#include<queue>
#include<cmath>
#include<cctype>
#include<string>
#include<algorithm>
#include<iostream>
#include<ctime>
#include<map>
#include<set>
using namespace std;
#define MP(x,y) make_pair((x),(y))
#define PB(x) push_back(x)
typedef long long LL;
//typedef unsigned __int64 ULL;
/* ****************** */
const int INF=100011122;
const double INFF=1e100;
const double eps=1e-8;
const int mod=1000000007;
const int NN=1000010;
const int MM=1000010;
/* ****************** */

struct G
{
    int v,w1,w2,next;
}E[NN*2];
int p[NN],T;
bool vis[NN];
LL dp[NN][5];
int si[NN];

void add(int u,int v,int w1,int w2)
{
    E[T].v=v;
    E[T].w1=w1;
    E[T].w2=w2;
    E[T].next=p[u];
    p[u]=T++;
}

void dfs(int u,int fa)
{
    int i,v;
    si[u]=vis[u];

    LL temp=0,t2,t3=0,t41=0;
    LL f1,f2,ff1,ff2;
    int match = -1;
    f1=f2=1LL<<50;

    dp[u][1]=0;
    dp[u][2]=0;
    dp[u][4]=0;
    for(i=p[u];i+1;i=E[i].next)
    {
        v=E[i].v;
        if(v==fa)continue;
        dfs(v,u);
        si[u]+=si[v];
        if(si[v]>0)
        {
            temp=min(temp,dp[v][0]-dp[v][1]-E[i].w1);

            dp[u][1]+=dp[v][1]+E[i].w1*2;

            t2=min(dp[v][1]+E[i].w1*2,dp[v][2]+E[i].w2*2);

            dp[u][2]+=t2;

            t3=min(t3,E[i].w1+E[i].w2+dp[v][3]-t2);

            dp[u][4]+=t2;
            t41=min(t41, min(dp[v][0]+E[i].w1,dp[v][4]+E[i].w2) - t2 );

            ff1=E[i].w1+E[i].w2+dp[v][3]-t2;
            if(ff1<f1)
            {
                match = v;
                f2 = f1;
                f1 = ff1;
            }
            else
                f2=min(f2,ff1);
        }
    }

    dp[u][0]=dp[u][1]+temp;
    dp[u][3]=dp[u][2]+t3;
    dp[u][4]+=t41;
    dp[u][4]=min(dp[u][4],dp[u][3]);

    for(i=p[u];i+1;i=E[i].next)
    {
        v=E[i].v;
        if(v==fa)continue;
        if(si[v]>0)
        {
            t2=min(dp[v][1]+E[i].w1*2,dp[v][2]+E[i].w2*2);
            ff2=E[i].w1+dp[v][0]-t2;
            if(v==match)
                dp[u][4] = min(dp[u][4], dp[u][2] + f2 + ff2);
            else
                dp[u][4] = min(dp[u][4], dp[u][2] + f1 + ff2);
        }
    }
}

int main()
{
    int n,m,i,u,v,w1,w2;
    while(scanf("%d",&n)!=EOF)
    {
        memset(p,-1,sizeof(p));
        T=0;
        memset(vis,false,sizeof(vis));

        for(i=1;i<n;i++)
        {
            scanf("%d%d%d%d",&u,&v,&w1,&w2);
            add(u,v,w1,w2);
            add(v,u,w1,w2);
        }
        scanf("%d",&m);
        for(i=1;i<=m;i++)
        {
            scanf("%d",&u);
            vis[u]=true;
        }

        dfs(1,-1);

        cout<<min(dp[1][0],dp[1][4])<<endl;
    }
    return 0;
}



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章