[DS记录]P5439 【XR-2】永恒

· · 个人记录

题意 : 给出两棵树 T_1,T_2 ,标号对应。

定义 lcadep(u,v)T_2u,v 两点的 \rm lca 的深度,也简记为 ld(u,v).

定义 T_1 上一条路径 P:u\leftrightarrow v 的价值为

\sum\limits_{a\in P,b\in P,a<b}ld(a,b)

T_1 所有无向路径的价值和,答案对 998244353 取模。

------------ 统计树上点对贡献。 考虑点分治,每次处理通过分治中心 $t$ 的所有路径的贡献。 令整颗原树以 $t$ 为根,设 $siz[u]$ 为 $u$ 点的子树大小。其求解可以借助点分治。 若 $(u,v)$ 不在同一个子树内,同时经过 $u,v$ 的路径总数即为 $siz[u]siz[v]$。 于是问题变成了 : $$\sum_{u}\sum_{v}ld(u,v)siz[u]siz[v]$$ 这可以在 $T_2$ 上建立虚树,然后可以树形 $\rm DP$ 计算。 利用归并和 $O(1)\ \rm LCA$ 建立虚树,则复杂度为 $O(n\log n+m\log m)$。 好难写啊qwq ```cpp #include<algorithm> #include<cstdio> #include<vector> #define pb push_back #define MaxN 300500 using namespace std; const int mod=998244353; int read(){ int X=0;char ch=0; while(ch<48||ch>57)ch=getchar(); while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar(); return X; } int n1,n2,ans; namespace Tree{ vector<int> g[MaxN]; int dfn[MaxN],out[MaxN],tim,dep[MaxN],sp[21][MaxN<<1]; bool cmp(int A,int B){return dfn[A]<dfn[B];} void pfs(int u){ sp[0][dfn[u]=++tim]=u; for (int i=0;i<g[u].size();i++){ dep[g[u][i]]=dep[u]+1; pfs(g[u][i]); sp[0][++tim]=u; }out[u]=tim; } int lg2[MaxN<<1]; inline int minp(int u,int v) {return dep[u]<dep[v] ? u:v;} int lca(int x,int y) { x=dfn[x];y=dfn[y]; if (x>y)swap(x,y); int k=lg2[y-x+1]; return minp(sp[k][x],sp[k][y-(1<<k)+1]); } void Init() { dep[0]=dep[1]=1;pfs(1); for (int j=0;(1<<(j+1))<=tim;j++) for (int i=1;i+(1<<j+1)-1<=tim;i++) sp[j+1][i]=minp(sp[j][i],sp[j][i+(1<<j)]); for (int i=2;i<=tim;i++)lg2[i]=lg2[i>>1]+1; } struct Line{int nxt,t;}l[MaxN]; int fir[MaxN],tl; inline void adl(int u,int v) {l[++tl]=(Line){fir[u],v};fir[u]=tl;} int w[MaxN],s[MaxN],ret; void dfs(int u,int fa) { s[u]=w[u]; for (int i=fir[u],v;i;i=l[i].nxt) {dfs(v=l[i].t,u);s[u]=(s[u]+s[v])%mod;} ret=(ret+1ll*s[u]*s[u]%mod*(dep[u]-dep[fa]))%mod; fir[u]=0;w[u]=0; } int stk[MaxN]; int calc(int *p,int m) { int top=tl=0;p[0]=1;p[m+1]=0; for (int i=0;i<=m;i++){ if (p[i]==p[i+1])continue; if (top&&out[stk[top]]<dfn[p[i]]){ int t=lca(stk[top],p[i]); while(top>2&&dep[stk[top-1]]>dep[t]) {adl(stk[top-1],stk[top]);top--;} if (t==stk[top-1]){adl(stk[top-1],stk[top]);top--;} else {adl(t,stk[top]);stk[top]=t;} }stk[++top]=p[i]; }while(top>1){adl(stk[top-1],stk[top]);top--;} ret=0;dfs(1,0); return ret; } }; vector<int> g[MaxN]; int _fa[MaxN],_siz[MaxN]; void _pfs(int u) { _siz[u]=1; for (int i=0,v;i<g[u].size();i++) if (!_siz[v=g[u][i]]){ _fa[v]=u;_pfs(v); _siz[u]+=_siz[v]; } } int siz[MaxN],mp[MaxN],st[MaxN],tn; bool vis[MaxN]; void pfs(int u,int fa) { siz[st[++tn]=u]=1;mp[u]=0; for (int i=0,v;i<g[u].size();i++) if ((v=g[u][i])!=fa&&!vis[v]){ pfs(v,u); siz[u]+=siz[v]; mp[u]=max(mp[u],siz[v]); } } int grt(int u) { int rt=tn=0;pfs(u,0); for (int i=1;i<=tn;i++){ int u=st[i]; mp[u]=max(mp[u],tn-siz[u]); if (mp[u]<mp[rt])rt=u; }return rt; } int rt,buf,tp[MaxN],w[21][MaxN],*tw; void dfs(int u,int fa) { tw[u]=(fa==_fa[u]) ? _siz[u] : n1-_siz[fa]; ans=(ans+1ll*buf*tw[u]%mod*(Tree::dep[Tree::lca(tp[rt],tp[u])]-1))%mod; for (int i=0,v;i<g[u].size();i++) if (!vis[v=g[u][i]]&&v!=fa) dfs(v,u); } int fa[MaxN],dep[MaxN]; void solve(int u) { tw=w[dep[u]=dep[fa[u]]+1]; vis[u]=1; for (int i=0,v;i<g[u].size();i++) if (!vis[v=g[u][i]]){ rt=u;buf=(u==_fa[v]) ? n1-_siz[v] : _siz[u]; dfs(v,u); } for (int i=0,v;i<g[u].size();i++) if (!vis[v=g[u][i]]){ fa[v=grt(v)]=u; solve(v); } } void input() { n1=read();n2=read(); for (int i=1;i<=n1;i++){ int v=read(); if (v){g[v].pb(i);g[i].pb(v);} } read(); for (int i=2;i<=n2;i++) Tree::g[read()].pb(i); scanf("%*s"); for (int i=1;i<=n1;i++)tp[i]=read(); } vector<int> sp[MaxN]; int p[MaxN]; bool cmp(int A,int B) {return Tree::dfn[tp[A]]<Tree::dfn[tp[B]];} void Init() { Tree::Init(); _pfs(1);mp[0]=n1;solve(grt(1)); for (int i=1;i<=n1;i++)p[i]=i; sort(p+1,p+n1+1,cmp); for (int i=1;i<=n1;i++) for (int u=p[i];u;u=fa[u]) sp[u].pb(p[i]); } int main() { input();Init(); ans=ans*2%mod; for (int u=1;u<=n1;u++){ if (sp[u].size()<=1)continue; for (int i=0;i<sp[u].size();i++) Tree::w[st[i+1]=tp[sp[u][i]]]=0; for (int i=0;i<sp[u].size();i++){ if (sp[u][i]==u)continue; int c=w[dep[u]][sp[u][i]]; Tree::w[st[i+1]]=(Tree::w[st[i+1]]+c)%mod; ans=(ans-1ll*c*c%mod*(Tree::dep[st[i+1]]-1))%mod; }ans=(ans+Tree::calc(st,sp[u].size()))%mod; for (int i=0;i<sp[u].size();i++) Tree::w[st[i+1]=tp[sp[u][i]]]=0; for (int i=0;i<sp[u].size();i++){ int c=w[dep[u]-1][sp[u][i]]; Tree::w[st[i+1]]=(Tree::w[st[i+1]]+c)%mod; ans=(ans+1ll*c*c%mod*(Tree::dep[st[i+1]]-1))%mod; }ans=(ans-Tree::calc(st,sp[u].size()))%mod; } ans=1ll*ans*(mod+1)/2%mod; ans=(ans+mod)%mod; printf("%d",ans); return 0; } ```