P5024

· · 个人记录

[NOIP2018 提高组] 保卫王国

调了一个下午一个小时加一节晚自习终于写完了。。。

自找的麻烦,自己给自己创造了超级多需要注意的细节。

先是一个整体的思路。

我们首先往暴力的 O(n^2) 想,只要 DP,然后修改节点值之后再 DP,就行了,当然注意一些与 LCA 的关系和细节,目前遇到的有分类讨论 LCA 等于链的一端一种,反之一种,最后根节点等于端点的一种。

暴力之后就能够拿到 44pts 的好成绩。

然后是正解。一眼动态DP裸题倍增优化。

然后是一些比较麻烦的事情,我们这个待修改的倍增 DP 要注意的细节超级多。

首先定义 g(i,j,p,q),表示 i 节点的状态为 pi2^j 祖先的状态为 q 的时候,i2^j 祖先的子树中除去子树 i 的值,剩下的值。

然后我们就可以瞎搞(因为这样想过之后方法就是会很麻烦)。

转移随便搞搞罢,注意时间复杂度,,我一开始居然写退化了。。。(丢脸)然后就是注意细节,虽然我们定义两个节点的状态 g 是确定的,但不意味着使用 f 转移的时候就要按照这个状态转移,因为本质上我们还是在还原一个特殊的状态,具体可以看预处理的一段注释。倍增转移的时候讨论的是中间的节点状态。

首先分类讨论,如果链的一端就是 LCA 的话,我们从较低的点往上跳,然后跳到 LCA 下面再转移给 LCA,然后再从 LCA 跳到根节点。

如果是一般情况,我们从两端开始跳到 LCA 下面,LCA 单独转移(注意还要加上其他子节点),然后从 LCA 再转移到根节点。

实现的细节很多。。。其实主要都是跳的过程中有点重合的存在,需要单独拎出来。然后就是重新修改 sum(i,j) 的值需要用一个记录数组防止退化,使用 memset 必死。

对拍+查错弄了几乎一整个晚自习,自闭了。代码弄到 6.83KB,自闭了。至今写过的最长的代码。。。

时间复杂度 O(n\log n)

代码:

#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll inf=(1ll<<40),N=1e5;

ll n,m,u,v,tot,cnt,a,b,x,y,ans;

char type[5];

ll lg[N+5],dt[N+5],fa[N+5][20],g[N+5][20][2][2],pos[N+5],f[N+5][2],sum[N+5][2];

bool vis[N+5][2];

ll cst[N+5],ver[N*2+5],nxt[N*2+5],head[N+5];

ll getlca(ll a,ll b) {
    if(dt[a]<dt[b]) swap(a,b);
    while(dt[a]>dt[b]) a=fa[a][lg[dt[a]-dt[b]]-1];
    if(a==b) return a;
    for(ll k=lg[dt[a]];k>=0;k--) {
        if(fa[a][k]!=fa[b][k]) {
            a=fa[a][k];b=fa[b][k];
        }
    }
    return fa[a][0];
}

void dfs(ll p,ll fath) {
    dt[p]=dt[fath]+1;fa[p][0]=fath;
    if(fath!=0) {
        g[p][0][1][1]=f[fath][1]-min(f[p][0],f[p][1]);//注意这里 -min 的还原
        g[p][0][1][0]=f[fath][0]-f[p][1];//我们转移 f 的时候也特判了,这里同理
        g[p][0][0][0]=inf;
        g[p][0][0][1]=f[fath][1]-min(f[p][0],f[p][1]);//同理
    }
    for(ll i=1;i<=lg[dt[p]];i++) {
        fa[p][i]=fa[fa[p][i-1]][i-1];
        g[p][i][0][0]=min(g[p][i-1][0][0]+g[fa[p][i-1]][i-1][0][0],g[p][i-1][0][1]+g[fa[p][i-1]][i-1][1][0]);
        g[p][i][0][1]=min(g[p][i-1][0][0]+g[fa[p][i-1]][i-1][0][1],g[p][i-1][0][1]+g[fa[p][i-1]][i-1][1][1]);
        g[p][i][1][0]=min(g[p][i-1][1][0]+g[fa[p][i-1]][i-1][0][0],g[p][i-1][1][1]+g[fa[p][i-1]][i-1][1][0]);
        g[p][i][1][1]=min(g[p][i-1][1][0]+g[fa[p][i-1]][i-1][0][1],g[p][i-1][1][1]+g[fa[p][i-1]][i-1][1][1]);
    }
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        dfs(ver[i],p);
    }
}

ll dp(ll p,ll stat,ll fath) {
    if(vis[p][stat]) return f[p][stat];vis[p][stat]=1;
    f[p][stat]=stat*cst[p];
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        if(stat) f[p][stat]+=min(dp(ver[i],0,p),dp(ver[i],1,p));
        else f[p][stat]+=dp(ver[i],1,p);
    }
    return f[p][stat];
}

void add(ll u,ll v) {
    ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

void restitute() {
//  printf("cnt=%lld\n",cnt);
    for(ll i=1;i<=cnt;i++) {
        sum[pos[i]][0]=sum[pos[i]][1]=inf;
    }
    cnt=0;
}

void init() {
    for(ll i=1;i<=n;i++) {
        sum[i][0]=sum[i][1]=inf;
    }
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    static char buf[22];static ll len=-1;
    if(x>=0) {
        do{buf[++len]=x%10+48;x/=10;}while(x);
    }
    else {
        putchar('-');
        do{buf[++len]=-(x%10)+48;x/=10;}while(x);
    }
    while(len>=0) putchar(buf[len--]);
}

void writeln(ll x) {
    write(x);putchar('\n');
}

int main() {

    n=read();m=read();cin>>type;

    for(ll i=1;i<=n;i++) {
        lg[i]=lg[i-1]+(1<<lg[i-1]==i);
    }

    for(ll i=1;i<=n;i++) {
        cst[i]=read();
    }

    for(ll i=1;i<n;i++) {
        u=read();v=read();
        add(u,v);add(v,u);
    }

    dp(1,0,0);dp(1,1,0);
    dfs(1,0);dfs(1,0);

//  for(ll i=1;i<=n;i++) {
//      for(ll j=0;j<=lg[dt[i]];j++) {
//          printf("When p=%lld fa=%lld j=%lld stat=0 stat_fa=0 g=%lld\n",i,fa[i][j],j,g[i][j][0][0]);
//          printf("When p=%lld fa=%lld j=%lld stat=0 stat_fa=1 g=%lld\n",i,fa[i][j],j,g[i][j][0][1]);
//          printf("When p=%lld fa=%lld j=%lld stat=1 stat_fa=0 g=%lld\n",i,fa[i][j],j,g[i][j][1][0]);
//          printf("When p=%lld fa=%lld j=%lld stat=1 stat_fa=1 g=%lld\n",i,fa[i][j],j,g[i][j][1][1]);
//      }
//  }
//  
    init();

    for(ll i=1;i<=m;i++) {
        a=read();x=read();b=read();y=read();
        if(dt[a]<dt[b]) {swap(a,b);swap(x,y);}
        ll lca=getlca(a,b);
    //  printf("a=%lld , b=%lld , lca=%lld\n",a,b,lca);
        if(b==lca) {
        //  printf("It is in the case 1:\n");
            sum[a][x]=f[a][x];pos[++cnt]=a;
        //  printf("sum[%lld][%lld]=%lld\n",a,x,sum[a][x]);
            for(ll j=lg[dt[a]-dt[b]];j>=0;j--) {
                if(dt[fa[a][j]]>dt[b]) {
                    sum[fa[a][j]][0]=min(g[a][j][0][0]+sum[a][0],g[a][j][1][0]+sum[a][1]);
                    sum[fa[a][j]][1]=min(g[a][j][0][1]+sum[a][0],g[a][j][1][1]+sum[a][1]);
                    a=fa[a][j];pos[++cnt]=a;
                }
            }
        //  printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",a,sum[a][0],a,sum[a][1]);
            if(y==0) sum[b][y]=sum[a][1]+g[a][0][1][y];
            if(y==1) sum[b][y]=min(sum[a][0]+g[a][0][0][y],sum[a][1]+g[a][0][1][y]);
            pos[++cnt]=b;
        //  printf("sum[%lld][%lld]=%lld ? sum[%lld][0]=%lld , sum[%lld][1]=%lld , g[%lld][0][0][%lld]=%lld , g[%lld][0][1][%lld]=%lld\n",b,y,sum[b][y],a,sum[a][0],a,sum[a][1],a,y,g[a][0][0][y],a,y,g[a][0][1][y]);
            for(ll j=lg[dt[b]];j>=0;j--) {
                if(dt[fa[b][j]]>dt[1]) {
                    sum[fa[b][j]][0]=min(g[b][j][0][0]+sum[b][0],g[b][j][1][0]+sum[b][1]);
                    sum[fa[b][j]][1]=min(g[b][j][0][1]+sum[b][0],g[b][j][1][1]+sum[b][1]);
                    b=fa[b][j];pos[++cnt]=b;
                }
            }
            if(b!=1) {
                sum[1][0]=min(sum[b][0]+g[b][0][0][0],sum[b][1]+g[b][0][1][0]);
                sum[1][1]=min(sum[b][0]+g[b][0][0][1],sum[b][1]+g[b][0][1][1]);
                pos[++cnt]=1;
            }
            ans=min(sum[1][0],sum[1][1]);
        }
        else {
        //  printf("It is in the case 2:\n");
            sum[a][x]=f[a][x];pos[++cnt]=a;
        //  printf("sum[%lld][%lld]=%lld\n",a,x,sum[a][x]);
            for(ll j=lg[dt[a]-dt[lca]];j>=0;j--) {
                if(dt[fa[a][j]]>dt[lca]) {
                    sum[fa[a][j]][0]=min(g[a][j][0][0]+sum[a][0],g[a][j][1][0]+sum[a][1]);
                    sum[fa[a][j]][1]=min(g[a][j][0][1]+sum[a][0],g[a][j][1][1]+sum[a][1]);
                    a=fa[a][j];pos[++cnt]=a;
                }
            }
        //  printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",a,sum[a][0],a,sum[a][1]);
            sum[b][y]=f[b][y];pos[++cnt]=b;
        //  printf("sum[%lld][%lld]=%lld\n",b,y,sum[b][y]);
            for(ll j=lg[dt[b]-dt[lca]];j>=0;j--) {
                if(dt[fa[b][j]]>dt[lca]) {
                    sum[fa[b][j]][0]=min(g[b][j][0][0]+sum[b][0],g[b][j][1][0]+sum[b][1]);
                    sum[fa[b][j]][1]=min(g[b][j][0][1]+sum[b][0],g[b][j][1][1]+sum[b][1]);
                    b=fa[b][j];pos[++cnt]=b;
                }
            }
        //  printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",b,sum[b][0],b,sum[b][1]);
            sum[lca][1]=cst[lca];sum[lca][0]=0;
        //  printf("lca=%lld\n",lca);
            for(ll j=head[lca];j;j=nxt[j]) {
                if(ver[j]==fa[lca][0]) continue;
                if(ver[j]==a||ver[j]==b) {
                    sum[lca][0]+=sum[ver[j]][1];
                    sum[lca][1]+=min(sum[ver[j]][0],sum[ver[j]][1]);
                //  printf("ver=%lld , sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",ver[j],ver[j],sum[ver[j]][0],ver[j],sum[ver[j]][1]);
                }
                else {
                    sum[lca][0]+=f[ver[j]][1];
                    sum[lca][1]+=min(f[ver[j]][0],f[ver[j]][1]);
                }
            }
            pos[++cnt]=lca;
        //  printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",lca,sum[lca][0],lca,sum[lca][1]);
            for(ll j=lg[dt[lca]];j>=0;j--) {
                if(dt[fa[lca][j]]>dt[1]) {
                    sum[fa[lca][j]][0]=min(g[lca][j][0][0]+sum[lca][0],g[lca][j][1][0]+sum[lca][1]);
                    sum[fa[lca][j]][1]=min(g[lca][j][0][1]+sum[lca][0],g[lca][j][1][1]+sum[lca][1]);
                    lca=fa[lca][j];pos[++cnt]=lca;
                }
            }
            if(lca!=1) {
                sum[1][0]=sum[lca][1]+g[lca][0][1][0];
                sum[1][1]=min(sum[lca][0]+g[lca][0][0][1],sum[lca][1]+g[lca][0][1][1]);
                pos[++cnt]=1;
            }
            ans=min(sum[1][0],sum[1][1]);
        }
//      for(ll i=1;i<=n;i++) {
//          printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",i,sum[i][0],i,sum[i][1]);
//      }
        if(ans>=inf) writeln(-1);
        else writeln(ans);
        restitute();
    }

    return 0;
}

代码(暴力):

#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll inf=(1ll<<40),N=1e5;

ll n,m,u,v,tot,cnt,a,b,x,y,ans,lca;

char type[5];

ll lg[N+5],dt[N+5],fa[N+5][20],g[N+5][2],pos[N+5],f[N+5][2],sum[N+5][2];

ll cst[N+5],ver[N*2+5],nxt[N*2+5],head[N+5];

void dfs(ll p,ll fath) {
    f[p][1]=cst[p];dt[p]=dt[fath]+1;fa[p][0]=fath;
    for(ll i=1;i<=lg[dt[p]];i++) {
        fa[p][i]=fa[fa[p][i-1]][i-1];
    }
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        dfs(ver[i],p);
        f[p][1]+=min(f[ver[i]][0],f[ver[i]][1]);
        f[p][0]+=f[ver[i]][1];
    }
}

ll getlca(ll a,ll b) {
    if(dt[a]<dt[b]) swap(a,b);
    while(dt[a]>dt[b]) a=fa[a][lg[dt[a]-dt[b]]-1];
    if(a==b) return a;
    for(ll k=lg[dt[a]];k>=0;k--) {
        if(fa[a][k]!=fa[b][k]) {
            a=fa[a][k];b=fa[b][k];
        }
    }
    return fa[a][0];
}

void dp(ll p,ll fath) {
    if(p==a) return;
    g[p][1]=cst[p];g[p][0]=0;
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        dp(ver[i],p);
        if(p!=b) {
            g[p][1]+=min(g[ver[i]][0],g[ver[i]][1]);
            g[p][0]+=g[ver[i]][1];
        }
        else {
            if(y==1) {g[p][1]+=min(g[ver[i]][0],g[ver[i]][1]);g[p][0]=inf;}
            if(y==0) {g[p][0]+=g[ver[i]][1];g[p][1]=inf;}
        }
    }
}

void _dp(ll p,ll fath) {
    if(p==lca) return;
    g[p][1]=cst[p];g[p][0]=0;
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        _dp(ver[i],p);
        g[p][1]+=min(g[ver[i]][0],g[ver[i]][1]);
        g[p][0]+=g[ver[i]][1];
    }
}

void __dp(ll p,ll fath) {
    if(p==a||p==b) return;
    g[p][1]=cst[p];g[p][0]=0;
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        __dp(ver[i],p);
        g[p][1]+=min(g[ver[i]][0],g[ver[i]][1]);
        g[p][0]+=g[ver[i]][1];
    }
}

void init() {
    for(ll i=1;i<=n;i++) {
        g[i][0]=g[i][1]=inf;
    }
}

void add(ll u,ll v) {
    ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    static char buf[22];static ll len=-1;
    if(x>=0) {
        do{buf[++len]=x%10+48;x/=10;}while(x);
    }
    else {
        putchar('-');
        do{buf[++len]=-(x%10)+48;x/=10;}while(x);
    }
    while(len>=0) putchar(buf[len--]);
}

void writeln(ll x) {
    write(x);putchar('\n');
}

int main() {

    n=read();m=read();cin>>type;

    for(ll i=1;i<=n;i++) {
        lg[i]=lg[i-1]+(1<<lg[i-1]==i);
    }

    for(ll i=1;i<=n;i++) {
        cst[i]=read();
    }

    for(ll i=1;i<n;i++) {
        u=read();v=read();
        add(u,v);add(v,u);
    }

    dfs(1,0);

//  for(ll i=1;i<=n;i++) {
//      printf("f[%lld][0]=%lld , f[%lld][1]=%lld\n",i,f[i][0],i,f[i][1]);
//  }

    for(ll i=1;i<=m;i++) {
        init();
        a=read();x=read();b=read();y=read();
        if(dt[a]<dt[b]) {swap(a,b);swap(x,y);}
        lca=getlca(a,b);
        if(lca==b) {
            g[a][x]=f[a][x];dp(lca,fa[lca][0]);
        }
        else {
            g[a][x]=f[a][x];g[b][y]=f[b][y];
            __dp(lca,fa[lca][0]);
        }
        _dp(1,0);
        if(b!=1) ans=min(g[1][0],g[1][1]);
        else ans=g[b][y];
//      printf("ans=%lld\n",ans);
//      for(ll j=1;j<=n;j++) {
//          printf("g[%lld][0]=%lld , g[%lld][1]=%lld\n",j,g[j][0],j,g[j][1]);
//      }
        if(ans>=inf) writeln(-1);
        else writeln(ans);
    }

    return 0;
}

代码(数据生成):

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<ctime>
#define ll long long
using namespace std;

ll n,m,fa,a,b,x,y;

ll random(ll x) {
    return rand()*rand()%x;
}

int main() {
    freopen("r.in","w",stdout);

    srand((unsigned)time(0));

    n=random(2000)+5;m=random(2000)+3;

    printf("%lld %lld C3\n",n,m);

    for(ll i=1;i<=n;i++) {
        printf("%lld ",random(5)+1);
    }

    printf("\n");

    for(ll i=1;i<n;i++) {
        fa=random(i)+1;printf("%lld %lld\n",fa,i+1);
    }

    for(ll i=1;i<=m;i++) {
        a=random(n)+1;b=a;
        while(b==a) b=random(n)+1;
        x=random(2);y=random(2);
        printf("%lld %lld %lld %lld\n",a,x,b,y);
    }

    fclose(stdout);
    return 0;
}