题解:P13241 「2.48sOI R1」格律树

· · 题解

虚树 + dp。

题目分析

先考虑特殊性质,有一条长为 d 的链不能出现 101,其它节点乱填。这时应该怎么办呢?

显然可以预处理一个状压 dp,f_{i,st} 表示此时长为 i,第 i-1 个为 st 的第一位,第 i 个为 st 的第二位。初始状态 f_{0,0}=1。转移只需枚举下一位为 x,当 2st+x=5 时禁止转移,否则加上去。

再考虑原问题。注意到 k 的总和是保证的,而且答案与非关键点几乎无关,因此考虑建出虚树,在虚树上 dp。记 dp_{u,st} 表示 ust 第一位,u父亲st 第二位的方案数。用父亲是因为只有父亲是唯一的,孩子可能有很多个。转移也比较显然,就是枚举虚树上的叶子节点 v 与状态 st',则根据加法原理和乘法原理,我们先统计此子树内到 u 状态 st 的方案数之和,再乘到 dp_{u,st} 上。但还有一个转移系数没有考虑。不妨改一下上面的状压 dp,f_{s,i,st} 表示此时长为 i,第 0 个为 s 的第一位,第 1 个为 s 的第二位,第 i-1 个为 st 的第一位,第 i 个为 st 的第二位。初始状态 f_{s,0,s}=1,转移类似。那么就可以把虚树拆成链用上边的方式求解。若一个节点已经的权值被钦定,那么就要把另一种选法的方案数设为 0,否则为 1。最终答案为 dp_{1,s} 之和,但注意不能让“1 的父亲”的值为 1。别忘了再乘上其它节点任意填的方案数。

于是就做完了。记 K=\sum k,时间复杂度 O(K \log K),瓶颈在建虚树。

AC Code

带快读板子,所以有点长。

#include<bits/stdc++.h>
//#include<bits/extc++.h>
//bool Mst;
using namespace std;
namespace wyzfastio
{
//#define usefio
#ifdef usefio
    namespace __getchar{const int bufsize=1<<20;char buf[bufsize<<1],*p1=buf,*p2=buf;inline char getchar(){return (p1==p2&&(p2=(p1=buf)+fread(buf,1,bufsize,stdin),p1==p2))?EOF:(*p1++);}}using __getchar::getchar;
    namespace __putchar{const int bufsize=1<<20;char buf[bufsize<<1],*p=buf;inline void putchar(const char ch){if(p-buf==bufsize) fwrite(buf,1,bufsize,stdout),p=buf;*p++=ch;}inline void flush(){fwrite(buf,1,p-buf,stdout);}}using __putchar::putchar;using __putchar::flush;
#endif
    /*---input---*/
    inline void read(unsigned long long &x){char ch=getchar();unsigned long long res=0;while(ch<'0'||ch>'9')ch=getchar();while(ch>='0'&&ch<='9')res=(res<<3)+(res<<1)+ch-'0',ch=getchar();x=res;}
    inline void read(unsigned int &x){char ch=getchar();unsigned int res=0;while(ch<'0'||ch>'9')ch=getchar();while(ch>='0'&&ch<='9')res=(res<<3)+(res<<1)+ch-'0',ch=getchar();x=res;}
    inline void read(long long &x){char ch=getchar();long long f=1;unsigned long long res=0;while(ch<'0'||ch>'9'){if(ch=='-')f=-f;ch=getchar();}while(ch>='0'&&ch<='9')res=(res<<3)+(res<<1)+ch-'0',ch=getchar();x=res*f;}
    inline void read(int &x){char ch=getchar();int f=1;unsigned res=0;while(ch<'0'||ch>'9'){if(ch=='-')f=-f;;ch=getchar();}while(ch>='0'&&ch<='9')res=(res<<3)+(res<<1)+ch-'0',ch=getchar();x=res*f;}
    inline void read(char &s){s=getchar();while(s==' '||s=='\n'||s=='\r') s=getchar();}
    inline int read(char *s){int i=0;char ch=getchar();while(ch==' '||ch=='\n'||ch=='\r') ch=getchar();while(!(ch==' '||ch=='\n'||ch=='\r'||ch==EOF)) s[++i]=ch,ch=getchar();s[i+1]='\0';return i;}
    inline void read(std::string &s){char ch=getchar();s.clear();while(ch==' '||ch=='\n'||ch=='\r') ch=getchar();while(!(ch==' '||ch=='\n'||ch=='\r'||ch==EOF)) s.push_back(ch),ch=getchar();}
    template<typename _Tp,typename ...Args>inline void read(_Tp &x,Args &...args){read(x),read(args...);}
    /*---output---*/
    inline void write(const unsigned long long x){if(x<10) putchar(x+'0');else write(x/10),putchar(x%10+'0');}
    inline void write(const long long x){unsigned long long t=x;if(x<0)putchar('-'),t=-x;write(t);}
    inline void write(const unsigned int x){if(x<10) putchar(x+'0');else write(x/10),putchar(x%10+'0');}
    inline void write(const int x){unsigned int t=x;if(x<0)putchar('-'),t=-x;write(t);}
    inline void write(const char x){putchar(x);}
    inline void write(const char *s){while(*s!='\0'&&*s!=EOF)putchar(*s),s++;}
    template<typename _Tp,typename ...Args>inline void write(_Tp x,Args ...args){write(x),write(args...);}
}
using namespace wyzfastio;
using ll=long long;
using ld=long double;
//#define int ll
using pii=pair<int,int>;
const int N=2e6+5,mod=1e9+7;
inline ll qpow(ll a,ll b,ll M=mod){ll res=1;while(b){if(b&1)res=1ll*res*a%M;b>>=1,a=1ll*a*a%M;}return res;}
inline ll rd(ll x,ll M=mod){return x>=M?x-M:x;}
inline ll pr(ll x,ll M=mod){return x<0?x+M:x;}
vector<int> g[N],t[N];
int siz[N],son[N],fa[N],dep[N],dfn[N],top[N],dcnt;
void dfs1(int u)
{
    siz[u]=1;
    for(int v:g[u]) if(v!=fa[u])
    {
        fa[v]=u,dep[v]=dep[u]+1,dfs1(v),siz[u]+=siz[v];
        if(!son[u]||siz[son[u]]<siz[v]) son[u]=v;
    }
}
void dfs2(int u,int tp)
{
    dfn[u]=++dcnt;
    top[u]=tp;
    if(!son[u]) return;
    dfs2(son[u],tp);
    for(int v:g[u]) if(v!=fa[u]&&v!=son[u]) dfs2(v,v);
}
int lca(int u,int v)
{
    while(top[u]!=top[v])
    {
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        u=fa[top[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    return u;
}
int n,f[4][N][4],s[N],p[N],dp[N][4],tmp[4];
void dpdfs(int u)
{
    for(int k=0;k<4;k++) dp[u][k]=1;
    if(p[u]) for(int k=0;k<4;k++) if((k>>1)!=s[u]) dp[u][k]=0;
    for(int v:t[u])
    {
        dpdfs(v);
        for(int st=0;st<4;st++)
            for(int sst=0;sst<4;sst++)
            {
                if(p[u]&&(sst>>1)!=s[u]) continue;
                tmp[sst]=(tmp[sst]+1ll*dp[v][st]*f[st][dep[v]-dep[u]][sst])%mod;
            }
        for(int k=0;k<4;k++) dp[u][k]=1ll*dp[u][k]*tmp[k]%mod,tmp[k]=0;
    }
}
int solve(vector<pii> &v)
{
    ll res=n-1,ans=0;
    vector<int> vv;
    vv.push_back(1);
    for(auto e:v) p[e.first]=1,s[e.first]=e.second,vv.push_back(e.first);
    sort(vv.begin(),vv.end(),[&](int x,int y){return dfn[x]<dfn[y];});
    for(int i=1,l=(int)vv.size();i<l;i++) vv.push_back(lca(vv[i-1],vv[i]));
    sort(vv.begin(),vv.end(),[&](int x,int y){return dfn[x]<dfn[y];});
    vv.erase(unique(vv.begin(),vv.end()),vv.end());
    for(int i=1,l=(int)vv.size();i<l;i++) t[lca(vv[i],vv[i-1])].push_back(vv[i]);
    for(int u:vv) for(int v:t[u]) res-=dep[v]-dep[u];
    dpdfs(1);
    for(int k=0;k<4;k++) if(!(k&1)) ans=(ans+dp[1][k])%mod;
    for(int i:vv)
    {
        t[i].clear(),s[i]=0,p[i]=0;
        for(int k=0;k<4;k++) dp[i][k]=0;
    }
    return 1ll*ans*qpow(2,res)%mod;
}
//bool Med;
signed main()
{
//  cerr<<"Memory Size: "<<abs((&Med)-(&Mst))/1024.0/1024<<" MB\n";
//  ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
//  freopen("in.in","r",stdin);
//  freopen("out.out","w",stdout);
    read(n);
    for(int i=1,u,v;i<n;i++) read(u,v),g[u].push_back(v),g[v].push_back(u);
    for(int st=0;st<4;st++)
    {
        f[st][0][st]=1;
        for(int i=0;i<n;i++)
            for(int s=0;s<4;s++)
                for(int j=0;j<2;j++) if((s<<1)+j!=5)
                {
                    int t=((s<<1)|j)&3;
                    f[st][i+1][t]=rd(f[st][i+1][t]+f[st][i][s]);
                }
    }
    dfs1(1),dfs2(1,1);
    int t,q;ll sum=0;
    read(t,q);
    vector<pii> vec;
    for(int i=1;i<=t;i++)
    {
        int k;read(k);vec.clear();
        for(int i=1,x,y;i<=k;i++) read(x,y),vec.push_back({x,y});
        ll ans=solve(vec);sum^=ans;
        if(i%q==0) write(sum,'\n'),sum=0;
    }
#ifdef usefio
    flush();
#endif
    return 0;
}