[DS记录]P4565 [CTSC2018]暴力写挂

· · 个人记录

题意 : 给出两棵有根树T_1,T_2,标号对应,有边权,可正可负。

对于点对(u,v),贡献为T_1.dep(u)+T_1.dep(v)-T_1.lcadep(u,v)-T_2.lcadep(u,v)

求点对的最大贡献。

------------ # 解法1 : 点分治+虚树 考虑 $T_1.dis(u,v)+T_2.dis(u,v)+(T_2.dep(u)-T_1.dep(u))+(T_2.dep(v)-T_1.dep(v))

不难发现这是原定义贡献的两倍

忽视dep的性质,设w[u]=(T_2.dep(u)-T_1.dep(u))

现在变成了求T_1.dis(u,v)+T_2.dis(u,v)+w[u]+w[v]的最小值。

T_1上点分治,考虑跨越分治中心的路径的贡献。

这时,贡献是T_2.dis(u,v)+w[u]+w[v]+dep[u]+dep[v](以重心为根)

我们只需定义w'[u]=w[u]+dep[u],建立虚树就变成了子问题。

现在问题是求T.dis(u,v)+w[u]+w[v]的最大值。

问题在于,对于来自同一子树内的两个点,w[u]+w[v]并不等于dis(u,v),这样的贡献不合法。

我们根据子树来染色,钦定不同颜色才能贡献。注意建虚树时添加的辅助点不能贡献,可以把w置为-INF.

考虑对每个节点求出子树内dep+w的最大点和次大点(异色),必须来自不同子树。

合起来就是该点子树内最长路。

DP即可,注意每个点要保留不同颜色的最大值和次大值,因为最大值可能被碰掉,需要一个补刀的。

来自不同子树这一条件,可以在依次dfs儿子的时候,先check贡献,再把自身加入。

复杂度 O(n\log^2n) ,瓶颈在建虚树。

发现具体时间瓶颈居然在线性的vector建图&释放,不得不写了邻接链表。

比较卡常。

#include<algorithm>
#include<cstdio>
#include<vector>
#define pb push_back
#define ll long long
#define INF (1ll<<60)
#define MaxN 370000
using namespace std;
namespace IO{
  const unsigned int __=1<<16;
  static char Ch[__],*_S=Ch,*_T=Ch;
  inline char getc()
  {return((_S==_T)&&(_T=(_S=Ch)+fread(Ch,1,__,stdin),_S==_T)?0:*_S++);}
  inline int read(){
    int X=0;char ch=0,fl=0;
    while(ch<48||ch>57)ch=getchar(),fl|=(ch=='-');
    while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
    return fl ? -X:X;
  }
}using namespace IO;
#define l l2
struct Line
{int t,nxt;ll l;}l[MaxN<<1];
int tl,fir[MaxN];
void clr(int n)
{for (int i=1;i<=n;i++)fir[i]=0;tl=0;}
void adl(int f,int t,ll len){
  l[++tl]=(Line){t,fir[f],len};fir[f]=tl;
  l[++tl]=(Line){f,fir[t],len};fir[t]=tl;
}
ll ws[MaxN],w[MaxN],ans=-INF;
int col[MaxN];
struct Data{
  ll x,x2;int c;
  inline void chk(const Data &t){
    if (t.c==c){
      x=max(x,t.x);
      x2=max(x2,t.x2);
    }else {
      if (t.x>=x){
        x2=max(x,t.x2);
        c=t.c;x=t.x;
      }else x2=max(x2,t.x);
    }
  }
}ff[MaxN];
#define f ff
void dfsv(int u,int fa)
{
  f[u].x=w[u];f[u].c=col[u];f[u].x2=-INF;
  for (int i=fir[u],v;i;i=l[i].nxt)
    if ((v=l[i].t)!=fa){
      dfsv(v,u);
      f[v].x+=l[i].l;f[v].x2+=l[i].l;
      if (f[u].c==f[v].c)
        ans=max(ans,max(f[u].x+f[v].x2,f[u].x2+f[v].x));
      else ans=max(ans,f[u].x+f[v].x);
      f[u].chk(f[v]);
    }
}
#undef f
struct Node
{int f,tf,siz,p,dep;}b[MaxN];
int tim,in[MaxN],out[MaxN];
ll d1[MaxN],d2[MaxN];
void pfs1(int u)
{
  b[u].siz=1;
  ws[u]=d1[u]-d2[u];
  for (int i=fir[u],v;i;i=l[i].nxt)
    if (!b[v=l[i].t].siz){
      b[v].dep=b[b[v].f=u].dep+1;
      d2[v]=d2[u]+l[i].l;
      pfs1(v);
      b[u].siz+=b[v].siz;
      if (b[v].siz>b[b[u].p].siz)
        b[u].p=v;
    }
}
void pfs2(int u,int tf)
{
  in[u]=++tim;b[u].tf=tf;
  if (b[u].p)
    for (int i=fir[u],v;i;i=l[i].nxt)
      if (!b[v=l[i].t].tf)
        pfs2(v,v==b[u].p ? tf : v);
  out[u]=tim;
}
int lca(int x,int y)
{
  while(b[x].tf!=b[y].tf){
    if (b[b[x].tf].dep<b[b[y].tf].dep)swap(x,y);
    x=b[b[x].tf].f;
  }return b[x].dep<b[y].dep ? x:y;
}
#undef l
int tn,e[MaxN],ef,tp[MaxN],st[MaxN];
struct Point{int x,c;}p[MaxN];
bool cmp(const Point &A,const Point &B)
{return in[A.x]<in[B.x];}
vector<int> g[MaxN],l[MaxN];
void dfs(int u)
{
  p[++tn]=(Point){u,e[u]=ef};
  for (int i=0,v;i<g[u].size();i++)
    if (e[v=g[u][i]]<ef)
      {d1[v]=d1[u]+l[u][i];dfs(v);}
}
#define ES 1000000000
void calc(int u)
{
  static bool vis[MaxN];
  d1[u]=0;p[tn=1]=(Point){u,-1};
  for (int i=0,v;i<g[u].size();i++)
    if (e[v=g[u][i]]<ES)
      {ef++;d1[v]=l[u][i];dfs(v);}
  sort(p+1,p+tn+1,cmp);
  for (int i=1;i<=tn;i++)vis[p[i].x]=1;
  int top=tn;
  for (int i=1,u;i<top;i++)
    if (!vis[u=lca(p[i].x,p[i+1].x)]){
      p[++tn]=(Point){u,-2};
      vis[u]=1;d1[u]=-INF;
    }
  sort(p+1,p+tn+1,cmp);top=0;
  for (int i=1,u;i<=tn;i++){
    tp[u=p[i].x]=i;
    w[i]=ws[u]+d1[u];
    col[i]=p[i].c;
    while(top>1&&out[u]>out[st[top]]){
      adl(tp[st[top-1]],tp[st[top]],d2[st[top]]-d2[st[top-1]]);
      top--;
    }st[++top]=u;
  }while(top>1){
    adl(tp[st[top-1]],tp[st[top]],d2[st[top]]-d2[st[top-1]]);
    top--;
  }dfsv(1,0);clr(tn);
  for (int i=1;i<=tn;i++)vis[p[i].x]=0;
}
int t[MaxN],ms[MaxN],siz[MaxN];
void pfs(int u,int fa)
{
  e[u]=ef;ms[u]=0;
  siz[t[++tn]=u]=1;
  for (int i=0,v;i<g[u].size();i++)
    if (e[v=g[u][i]]<ef){
      pfs(v,u);
      siz[u]+=siz[v];
      ms[u]=max(ms[u],siz[v]);
    }
}
int grt(int u)
{
  tn=0;ef++;pfs(u,0);
  int rt=0;
  for (int i=1;i<=tn;i++){
    ms[t[i]]=max(ms[t[i]],tn-siz[t[i]]);
    if (ms[t[i]]<ms[rt])rt=t[i];
  }return rt;
}
void solve(int u)
{
  if (siz[u]==1)return ;
  e[u]=ES;calc(u);
  for (int i=0,v;i<g[u].size();i++)
    if (e[v=g[u][i]]<ES)
      solve(grt(v));
}
int main()
{
  int n=read();
  for (int i=1,fr,to,len;i<n;i++){
    fr=read();to=read();len=read();
    g[fr].pb(to);l[fr].pb(len);
    g[to].pb(fr);l[to].pb(len);
  }ef++;dfs(1);
  for (int i=1,fr,to;i<n;i++){
    fr=read();to=read();
    adl(fr,to,read());
  }b[1].dep=1;pfs1(1);pfs2(1,1);clr(n);
  ms[0]=n+1;solve(grt(1));
  for (int i=1;i<=n;i++)ans=max(ans,ws[i]<<1);
  printf("%lld",ans>>1);
}

解法2 : 边分树合并

上一种做法复杂度太丑所以需要卡常,我们寻求更加有理有据的作法。

边分树合并,可以看做边分治和链分治的结合。

在每次边分的时候,记录点被分到了哪边,表示为01序列,能够建立一棵01Trie,且深度是严格O(\log n)的。

类似线段树合并一样跑边分树合并,复杂度是 O(n\log n)

我们在 T1 上边分治预处理,建立边分树。然后在 T2dfs 遍历 LCA ,依次合并边分树,并且计算贡献。

由于 T_2.lca(u,v)已经被确定了,我们要对付的就是 T_1.dep(u)+T_1.dep(v)-T_1.lcadep(u,v)

能转化为 \frac{1}{2}(T_1.dep(u)+T_1.dep(v)+T_1.dis(u,v))

回顾性质 : 边分树上的父亲边一定在任意两个儿子的简单路径上。

所以,如果 u,v 在边 (t1,t2) 的不同子树内, dis(u,v)=dis(u,t1)+dis(t2,v)

这样,在每个 \rm Trie 节点(代表着一条边)上分别维护左子树内的 dis(u,t1)+dep[u] ,右子树内的 dis(t2,v)+dep[v]

这里注意,\rm Trie 不同层级维护的东西是不同的,不需要pushup,直接对应位置取\max即可。

一开始边分治的时候可以顺便求出每个点到分治中心的距离,这样 dis(u,t1) 等就可以 O(1) 取用了。

边分树合并计算贡献的时候,我们会遍历两棵边分树(Trie)的重叠部分。

对于一个重叠的点对 a,b ,是这样贡献的 : a的左儿子 & b的右儿子 ; a 的右儿子 & b的左儿子。

这些都能够通过合并已经维护过的信息来求得。

这样复杂度就是严格 O(n\log n) 的了。

不知道是边分治自带双倍常数,还是我写丑了,并没有明显的速度优势。

#include<algorithm>
#include<cstdio>
#define ll long long
#define INF (1ll<<60)
#define MaxN 740000
#define forG(u) for (int i=fir[u],v;i;i=g[i].nxt)
using namespace std;
namespace IO{
  const unsigned int __=1<<17;
  static char Ch[__],*_S=Ch,*_T=Ch;
  inline char getc()
  {return((_S==_T)&&(_T=(_S=Ch)+fread(Ch,1,__,stdin),_S==_T)?0:*_S++);}
  inline int read(){
    int X=0;char ch=0,fl=0;
    while(ch<48||ch>57)ch=getc(),fl|=(ch=='-');
    while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getc();
    return fl ? -X:X;
  }
}using namespace IO;
struct Line
{int t,nxt,l;}g[MaxN<<1];
int tl=1,fir[MaxN];
void adl(int f,int t,int len){
  g[++tl]=(Line){t,fir[f],len};fir[f]=tl;
  g[++tl]=(Line){f,fir[t],len};fir[t]=tl;
}
int tot,st[MaxN],siz[MaxN],ms[MaxN],e[MaxN],ef;
void clear(int n)
{for (int i=1;i<=n;i++)e[i]=fir[i]=0;tl=1;}
void pfs(int u)
{
  e[u]=ef;siz[st[++tot]=u]=1;
  forG(u) if (e[v=g[i].t]<ef)
    {pfs(v);siz[u]+=siz[v];}
}
int grt(int u)
{
  tot=0;ef++;pfs(u);
  int rt=0;
  for (int i=1;i<=tot;i++){
    ms[st[i]]=max(siz[st[i]],tot-siz[st[i]]);
    if (ms[st[i]]<=ms[rt])rt=st[i];
  }return rt;
}
struct TrNode
{int f,dep;bool tr;}b[MaxN<<1];
int tn2;
ll d[36][MaxN],*dis;
void dfs(int u)
{
  e[u]=ef;
  forG(u) if (e[v=g[i].t]<ef){
    dis[v]=dis[u]+g[i].l;
    dfs(v);
  }
}
int solve(int u,int dep)
{
  int v,tp=0;
  for (int i=fir[u];i;i=g[i].nxt)
    if (siz[g[i].t]>siz[u])
      {v=g[tp=i].t;break;}
  if (!tp){b[u].dep=dep;return u;}
  dis=d[dep];ef++;dfs(u);
  g[tp].t=g[tp^1].t=0;
  int pos=++tn2;b[pos].dep=dep;
  b[u=solve(grt(u),dep+1)].tr=0;b[u].f=pos;
  b[v=solve(grt(v),dep+1)].tr=1;b[v].f=pos;
  return pos;
}
ll dep1[MaxN],dep2[MaxN];
void pfs2(int u)
{
  e[u]=1;
  forG(u) if (!e[v=g[i].t]){
    dep1[v]=dep1[u]+g[i].l;
    pfs2(v);
  }
}
struct Node
{int l,r;ll x;}a[MaxN*19];
int tn,rt[MaxN];
int build(const int u)
{
  int p=u;
  while(1){
    a[++tn].x=d[b[p].dep-1][u]+dep1[u];
    if (!b[p].f)break;
    if (b[p].tr)a[tn+1].r=tn;
    else a[tn+1].l=tn;
    p=b[p].f;
  }return tn;
}
ll ans=-INF;
int merge(int x,int y,ll buf)
{
  if (!x||!y)return x|y;
  ll sav=buf+max(a[a[x].l].x+a[a[y].r].x,a[a[x].r].x+a[a[y].l].x);
  ans=max(ans,sav);
  a[x].x=max(a[x].x,a[y].x);
  a[x].l=merge(a[x].l,a[y].l,buf);
  a[x].r=merge(a[x].r,a[y].r,buf);
  return x;
}
void dfs2(int u)
{
  e[u]=1;rt[u]=build(u);
  forG(u) if (!e[v=g[i].t]){
    dep2[v]=dep2[u]+g[i].l;
    dfs2(v);
    rt[u]=merge(rt[u],rt[v],-2*dep2[u]);
  }
}
struct SLine
{int f,t,l;}sl[MaxN];
int las[MaxN],tn3;
void cre(int u,int v,int len){
  if (siz[v]>siz[u])swap(u,v);
  if (!las[u]){las[u]=u;adl(u,v,len);}
  else {
    adl(++tn3,las[u],0);
    adl(las[u]=tn3,v,len);
  }
}
int main()
{
  int n=tn3=read();
  for (int i=1,fr,to;i<n;i++){
    fr=read();to=read();
    sl[i]=(SLine){fr,to,read()};
    adl(fr,to,sl[i].l);
  }pfs2(1);ef=2;pfs(1);clear(n);
  for (int i=1;i<n;i++)
    cre(sl[i].f,sl[i].t,sl[i].l);
  tn2=tn3;
  ms[0]=e[0]=(1<<30);solve(grt(1),1);
  clear(tn3);
  for (int i=1,fr,to;i<n;i++){
    fr=read();to=read();
    adl(fr,to,read());
  }a[0].x=-INF;dfs2(1);
  ans>>=1;
  for (int i=1;i<=n;i++)ans=max(ans,dep1[i]-dep2[i]);
  printf("%lld",ans);
}