[DS记录]P5439 【XR-2】永恒
command_block
·
·
个人记录
题意 : 给出两棵树 T_1,T_2 ,标号对应。
定义 lcadep(u,v) 为 T_2 上 u,v 两点的 \rm lca 的深度,也简记为 ld(u,v).
定义 T_1 上一条路径 P:u\leftrightarrow v 的价值为
\sum\limits_{a\in P,b\in P,a<b}ld(a,b)
求 T_1 所有无向路径的价值和,答案对 998244353 取模。
------------
统计树上点对贡献。
考虑点分治,每次处理通过分治中心 $t$ 的所有路径的贡献。
令整颗原树以 $t$ 为根,设 $siz[u]$ 为 $u$ 点的子树大小。其求解可以借助点分治。
若 $(u,v)$ 不在同一个子树内,同时经过 $u,v$ 的路径总数即为 $siz[u]siz[v]$。
于是问题变成了 :
$$\sum_{u}\sum_{v}ld(u,v)siz[u]siz[v]$$
这可以在 $T_2$ 上建立虚树,然后可以树形 $\rm DP$ 计算。
利用归并和 $O(1)\ \rm LCA$ 建立虚树,则复杂度为 $O(n\log n+m\log m)$。
好难写啊qwq
```cpp
#include<algorithm>
#include<cstdio>
#include<vector>
#define pb push_back
#define MaxN 300500
using namespace std;
const int mod=998244353;
int read(){
int X=0;char ch=0;
while(ch<48||ch>57)ch=getchar();
while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
return X;
}
int n1,n2,ans;
namespace Tree{
vector<int> g[MaxN];
int dfn[MaxN],out[MaxN],tim,dep[MaxN],sp[21][MaxN<<1];
bool cmp(int A,int B){return dfn[A]<dfn[B];}
void pfs(int u){
sp[0][dfn[u]=++tim]=u;
for (int i=0;i<g[u].size();i++){
dep[g[u][i]]=dep[u]+1;
pfs(g[u][i]);
sp[0][++tim]=u;
}out[u]=tim;
}
int lg2[MaxN<<1];
inline int minp(int u,int v)
{return dep[u]<dep[v] ? u:v;}
int lca(int x,int y)
{
x=dfn[x];y=dfn[y];
if (x>y)swap(x,y);
int k=lg2[y-x+1];
return minp(sp[k][x],sp[k][y-(1<<k)+1]);
}
void Init()
{
dep[0]=dep[1]=1;pfs(1);
for (int j=0;(1<<(j+1))<=tim;j++)
for (int i=1;i+(1<<j+1)-1<=tim;i++)
sp[j+1][i]=minp(sp[j][i],sp[j][i+(1<<j)]);
for (int i=2;i<=tim;i++)lg2[i]=lg2[i>>1]+1;
}
struct Line{int nxt,t;}l[MaxN];
int fir[MaxN],tl;
inline void adl(int u,int v)
{l[++tl]=(Line){fir[u],v};fir[u]=tl;}
int w[MaxN],s[MaxN],ret;
void dfs(int u,int fa)
{
s[u]=w[u];
for (int i=fir[u],v;i;i=l[i].nxt)
{dfs(v=l[i].t,u);s[u]=(s[u]+s[v])%mod;}
ret=(ret+1ll*s[u]*s[u]%mod*(dep[u]-dep[fa]))%mod;
fir[u]=0;w[u]=0;
}
int stk[MaxN];
int calc(int *p,int m)
{
int top=tl=0;p[0]=1;p[m+1]=0;
for (int i=0;i<=m;i++){
if (p[i]==p[i+1])continue;
if (top&&out[stk[top]]<dfn[p[i]]){
int t=lca(stk[top],p[i]);
while(top>2&&dep[stk[top-1]]>dep[t])
{adl(stk[top-1],stk[top]);top--;}
if (t==stk[top-1]){adl(stk[top-1],stk[top]);top--;}
else {adl(t,stk[top]);stk[top]=t;}
}stk[++top]=p[i];
}while(top>1){adl(stk[top-1],stk[top]);top--;}
ret=0;dfs(1,0);
return ret;
}
};
vector<int> g[MaxN];
int _fa[MaxN],_siz[MaxN];
void _pfs(int u)
{
_siz[u]=1;
for (int i=0,v;i<g[u].size();i++)
if (!_siz[v=g[u][i]]){
_fa[v]=u;_pfs(v);
_siz[u]+=_siz[v];
}
}
int siz[MaxN],mp[MaxN],st[MaxN],tn;
bool vis[MaxN];
void pfs(int u,int fa)
{
siz[st[++tn]=u]=1;mp[u]=0;
for (int i=0,v;i<g[u].size();i++)
if ((v=g[u][i])!=fa&&!vis[v]){
pfs(v,u);
siz[u]+=siz[v];
mp[u]=max(mp[u],siz[v]);
}
}
int grt(int u)
{
int rt=tn=0;pfs(u,0);
for (int i=1;i<=tn;i++){
int u=st[i];
mp[u]=max(mp[u],tn-siz[u]);
if (mp[u]<mp[rt])rt=u;
}return rt;
}
int rt,buf,tp[MaxN],w[21][MaxN],*tw;
void dfs(int u,int fa)
{
tw[u]=(fa==_fa[u]) ? _siz[u] : n1-_siz[fa];
ans=(ans+1ll*buf*tw[u]%mod*(Tree::dep[Tree::lca(tp[rt],tp[u])]-1))%mod;
for (int i=0,v;i<g[u].size();i++)
if (!vis[v=g[u][i]]&&v!=fa)
dfs(v,u);
}
int fa[MaxN],dep[MaxN];
void solve(int u)
{
tw=w[dep[u]=dep[fa[u]]+1];
vis[u]=1;
for (int i=0,v;i<g[u].size();i++)
if (!vis[v=g[u][i]]){
rt=u;buf=(u==_fa[v]) ? n1-_siz[v] : _siz[u];
dfs(v,u);
}
for (int i=0,v;i<g[u].size();i++)
if (!vis[v=g[u][i]]){
fa[v=grt(v)]=u;
solve(v);
}
}
void input()
{
n1=read();n2=read();
for (int i=1;i<=n1;i++){
int v=read();
if (v){g[v].pb(i);g[i].pb(v);}
}
read();
for (int i=2;i<=n2;i++)
Tree::g[read()].pb(i);
scanf("%*s");
for (int i=1;i<=n1;i++)tp[i]=read();
}
vector<int> sp[MaxN];
int p[MaxN];
bool cmp(int A,int B)
{return Tree::dfn[tp[A]]<Tree::dfn[tp[B]];}
void Init()
{
Tree::Init();
_pfs(1);mp[0]=n1;solve(grt(1));
for (int i=1;i<=n1;i++)p[i]=i;
sort(p+1,p+n1+1,cmp);
for (int i=1;i<=n1;i++)
for (int u=p[i];u;u=fa[u])
sp[u].pb(p[i]);
}
int main()
{
input();Init();
ans=ans*2%mod;
for (int u=1;u<=n1;u++){
if (sp[u].size()<=1)continue;
for (int i=0;i<sp[u].size();i++)
Tree::w[st[i+1]=tp[sp[u][i]]]=0;
for (int i=0;i<sp[u].size();i++){
if (sp[u][i]==u)continue;
int c=w[dep[u]][sp[u][i]];
Tree::w[st[i+1]]=(Tree::w[st[i+1]]+c)%mod;
ans=(ans-1ll*c*c%mod*(Tree::dep[st[i+1]]-1))%mod;
}ans=(ans+Tree::calc(st,sp[u].size()))%mod;
for (int i=0;i<sp[u].size();i++)
Tree::w[st[i+1]=tp[sp[u][i]]]=0;
for (int i=0;i<sp[u].size();i++){
int c=w[dep[u]-1][sp[u][i]];
Tree::w[st[i+1]]=(Tree::w[st[i+1]]+c)%mod;
ans=(ans+1ll*c*c%mod*(Tree::dep[st[i+1]]-1))%mod;
}ans=(ans-Tree::calc(st,sp[u].size()))%mod;
}
ans=1ll*ans*(mod+1)/2%mod;
ans=(ans+mod)%mod;
printf("%d",ans);
return 0;
}
```