P6773 [NOI2020] 命运(线段树合并优化dp)

· · 题解

P6773 [NOI2020] 命运

暴力:

我:暴力容斥,钦定一些限制不满足,复杂度 O(2^mm\log n),期望得分 32pts

gxy001:O(n^2) dp,期望得分 48pts

_sys:虚树优化,复杂度O(min(n,m)^2) ,期望得分 72pts

(我太菜了呀!)

正解:

性质:下端点相同的限制只有上端点深度最大的有用。

dp_{u,i} 表示下端点在 u 的子树内且未被满足的限制的上端点深度最大的为 i 时,u 子树内赋值的方案数。特别地,dp_{u,0} 为所有限制都满足的方案数。答案即为 dp_{1,0}

dp_{u,i}'=\sum_{j=0}^{dep_u}dp_{u,i}dp_{v,j}+\sum_{j=0}^idp_{u,i}dp_{v,j}+\sum_{j=0}^{i-1}dp_{u,j}dp_{v,i}

第一个 \sum(u,v) 为 1 的情况,后两个是为 0 的情况。用前缀和优化一下可得

dp_{u,i}'=dp_{u,i}(sum_{v,dep_u}+sum_{v,i})+dp_{v,i}sum_{u,i-1}

用线段树合并优化这个 dp 式子。先查询 sum_{v,dep_u},merge 的时候先处理左区间,处理的时候顺便求出区间和,处理右区间的时候就可以直接调用了。

#include<bits/stdc++.h>
#define mid ((l+r)>>1)
using namespace std;
namespace FGF
{
    int n,m;
    const int N=5e5+5,mo=998244353;
    vector<int> g[N],lm[N];
    int dep[N],rt[N],ls[N*32],rs[N*32],sum[N*32],num,ta[N*32];
    void updat(int &ro,int l,int r,int x)
    {
        if(!ro)ro=++num,ta[ro]=1;
        if(l==r){sum[ro]=1;return;}
        x<=mid?updat(ls[ro],l,mid,x):updat(rs[ro],mid+1,r,x);
        sum[ro]=(sum[ls[ro]]+sum[rs[ro]])%mo;
    }
    inline void pushdown(int x)
    {
        if(ta[x]!=1)
        {
            if(ls[x])ta[ls[x]]=1ll*ta[ls[x]]*ta[x]%mo,sum[ls[x]]=1ll*sum[ls[x]]*ta[x]%mo;
            if(rs[x])ta[rs[x]]=1ll*ta[rs[x]]*ta[x]%mo,sum[rs[x]]=1ll*sum[rs[x]]*ta[x]%mo;
            ta[x]=1;
        }
    }
    int merge(int x,int y,int l,int r,int &s1,int &s2)
    {
        if(!x&&!y)return 0;
        if(!x)
        {
            s1=(s1+sum[y])%mo,ta[y]=1ll*ta[y]*s2%mo,sum[y]=1ll*sum[y]*s2%mo;
            return y;
        }
        if(!y)
        {
            s2=(s2+sum[x])%mo,ta[x]=1ll*ta[x]*s1%mo,sum[x]=1ll*sum[x]*s1%mo;
            return x;
        }
        if(l==r)
        {
            int tmp=sum[x];
            s1=(s1+sum[y])%mo,sum[x]=(1ll*sum[x]*s1%mo+1ll*sum[y]*s2%mo)%mo,s2=(s2+tmp)%mo;
            return x;
        }
        pushdown(x),pushdown(y);
        ls[x]=merge(ls[x],ls[y],l,mid,s1,s2),rs[x]=merge(rs[x],rs[y],mid+1,r,s1,s2);
        sum[x]=(sum[ls[x]]+sum[rs[x]])%mo;
        return x;
    }
    int query(int ro,int l,int r,int R)
    {
        if(!ro)return 0;
        if(r<=R)return sum[ro];
        pushdown(ro);
        return (query(ls[ro],l,mid,R)+(R>mid?query(rs[ro],mid+1,r,R):0))%mo;
    }
    void dfs(int u,int f)
    {
        dep[u]=dep[f]+1;
        int mxd=0,s1=0,s2=0;
        for(auto v:lm[u])mxd=max(dep[v],mxd);
        updat(rt[u],0,n,mxd);
        for(auto v:g[u])
            if(v!=f)dfs(v,u),s2=0,s1=query(rt[v],0,n,dep[u]),rt[u]=merge(rt[u],rt[v],0,n,s1,s2);
    }
    void work()
    {
        scanf("%d",&n);
        for(int i=1,u,v;i<n;i++)
            scanf("%d%d",&u,&v),g[u].push_back(v),g[v].push_back(u);
        scanf("%d",&m);
        for(int i=1,u,v;i<=m;i++)
            scanf("%d%d",&u,&v),lm[v].push_back(u);
        dfs(1,0);
        printf("%d\n",query(rt[1],0,n,0));
    }
}
int main()
{
    FGF::work();
    return 0;
}