题解:P4775 [NOI2018] 情报中心

· · 题解

题解:P4775 [NOI2018] 情报中心

标程是人能写出来的吗。

简化题意

给你一棵有边权的树,然后给你 m 个点对,选取一个点对有相应的代价,每个点对的收益为两点在树上经过的最短路径权值和。

需要你选出两个点对,使两个点对的最短路径有交,问收益减去代价的最大值。

注:重复的边权值只算一次。

思路

遇到这种复杂的题,我们先将问题分割成若干个子问题:

  1. 如何找到合法点对?

  2. 如何计算两个合法点对的收益以及寻找最大值?

合法点对,我们先来画几张图来观察一下合法点对的性质:

这几种情况貌似就可以说明所有的情况了。

不难发现,两条线如果有交,那么其中一条的 lca 必然在另外一条的路径上。

所以我们的方向就逐渐转向 lca。

如何解决收益呢?重合的部分可能因为相交的形态不同而计算方式不同,所以我们果断放弃计算重合部分的方法。继续观察上图,不难发现:

2 \cdot ans=val=dis(u_1,v_1)+dis(u_2,v_2)+dis(u_1,u_2)+dis(v_1,v_2)-2 \cdot cost_1-2 \cdot cost_2

这好发现吗。

这式子有用吗?我们来化一化,首先我们先将只与自己有关的点对写在一起,因为它们可以预处理出来,设:

w(u)=dis(u,v)-2 \cdot cost

所以:

val=w(u_1)+w(u_2)+dis(u_1,u_2)+dis(v_1,v_2)

我们发现 dis(u_1,u_2) 可以化为只与 u 有关的式子:

dis(u_1,u_2)=dis(u_1,rt)+dis(u_2+rt)-dis(lca,rt)

然后我们发现这不是也可以预处理吗,我们又设:

w(u)=dis(u,v)+dis(u,rt)-2 \cdot cost

所以:

val=w(u_1)+w(u_2)-2 \cdot dis(lca(u_1,u_2),rt)+dis(v_1,v_2)

但是此时 -2 \cdot dis(lca(u_1,u_2),rt)+dis(v_1,v_2) 这个式子已经不再好化了,我们上文提到过由于我们需要保证选的点对合法,最简单的方向就是向 lca 靠。所以如果此时我们枚举 lca 那么-2 \cdot dis(lca(u_1,u_2),rt) 就会变为一个我们已知的定值,然后此时我们就只需要维护 dis(v_1,v_2) 了。

那这不就是树上最远点对了吗?你爱咋做咋做,像啥淀粉树,线段树呀维护一下就可以了。

我们这里考虑线段树合并的做法,对于每一个点开一颗线段树,树上维护最终组合对即可。

当然我们需要保证点对的合法性,所以考虑在某些位置删除掉不会做贡献的点的贡献。这里可以考虑树上差分,在询问时将点的贡献拍入该点的线段树上,在其到 lca 路径中的 lca 的儿子节点打上删除标记,到了就删就行了。

至于为什么要在儿子结点删除,看一下面的图吧。

这种组合显然不合法,但是如果我们在 lca 处删除 uv 的贡献,这种情况就会被统计入答案,所以我们要在 lca 的儿子结点删除。

但是还有问题,我们的 uv 不都像我们画的图一样有序,这是考虑对四种不同的组合方式枚举即可。

代码

喜滋滋的代码时间。

写完后才发现倍增求 lca 似乎更好写?

#include<bits/stdc++.h>
namespace fast_IO {
#define IOSIZE 1000000
    char ibuf[IOSIZE], obuf[IOSIZE], *p1 = ibuf, *p2 = ibuf, *p3 = obuf;
#define getchar() ((p1==p2)and(p2=(p1=ibuf)+fread(ibuf,1,IOSIZE,stdin),p1==p2)?(EOF):(*p1++))
#define putchar(x) ((p3==obuf+IOSIZE)&&(fwrite(obuf,p3-obuf,1,stdout),p3=obuf),*p3++=x)
#define isdigit(ch) (ch>47&&ch<58)
#define isspace(ch) (ch<33)
    template<typename T> inline T read() { T s = 0; int w = 1; char ch; while (ch = getchar(), !isdigit(ch) and (ch != EOF)) if (ch == '-') w = -1; if (ch == EOF) return false; while (isdigit(ch)) s = s * 10 + ch - 48, ch = getchar(); return s * w; }
    template<typename T> inline bool read(T &s) { s = 0; int w = 1; char ch; while (ch = getchar(), !isdigit(ch) and (ch != EOF)) if (ch == '-') w = -1; if (ch == EOF) return false; while (isdigit(ch)) s = s * 10 + ch - 48, ch = getchar(); return s *= w, true; }
    template<typename T> inline void print(T x) { if (x < 0) putchar('-'), x = -x; if (x > 9) print(x / 10); putchar(x % 10 + 48); }
    inline bool read(char &s) { while (s = getchar(), isspace(s)); return true; }
    inline bool read(char *s) { char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) *s++ = ch, ch = getchar(); *s = '\000'; return true; }
    inline void print(char x) { putchar(x); }
    inline void print(char *x) { while (*x) putchar(*x++); }
    inline void print(const char *x) { for (int i = 0; x[i]; i++) putchar(x[i]); }
    inline bool read(std::string& s) { s = ""; char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) s += ch, ch = getchar(); return true; }
    inline void print(std::string x) { for (int i = 0, n = x.size(); i < n; i++) putchar(x[i]); }
    inline bool read(bool &b) { char ch; while(ch=getchar(), isspace(ch)); b=ch^48; return true; }
    inline void print(bool b) { putchar(b+48); }
    template<typename T, typename... T1> inline int read(T& a, T1&... other) { return read(a) + read(other...); }
    template<typename T, typename... T1> inline void print(T a, T1... other) { print(a), print(other...); }
    struct Fast_IO { ~Fast_IO() { fwrite(obuf, p3 - obuf, 1, stdout); } } io;
    template<typename T> Fast_IO& operator >> (Fast_IO &io, T &b) { return read(b), io; }
    template<typename T> Fast_IO& operator << (Fast_IO &io, T b) { return print(b), io; }
#define cout io
#define cin io
#define endl '\n'
} using namespace fast_IO;
using namespace std;
#define lson t[k].ls
#define rson t[k].rs
#define int long long
//对于两个点对,如果合法,那么其贡献为val=dis(u_1,v_1)+dis(u_2,v_2)+dis(u_1,u_2)+dis(v_1,v_2)-2cost_1-2cost_2
//继续转化我们设一点对中一个点的贡献为w(u)=dis(u,v)-2cost,对于v是一样的
//所以val=w(u_1)+w(u_2)+dis(u_1,u_2)+dis(v_1,v_2)
//此时如果我们加入v_1,v_2的lca
//那么式子为val=w(u_1)+w(u_2)+dis(u_1,u_2)+dis(v_1,rt)+dis(v_2,rt)-2*dis(lca,rt)
//我们再次定义设w(u)=dis(u,v)-2cost+dis(v,rt)
//则val=w(u_1)+w(u_2)+dis(u_1,u_2)-2*dis(lca,rt)
//那么对于每个点的w我们均可以预处理出来。
//而如果我们枚举lca,则2*dis(lca,rt)为定值,那么我们需要维护的就只有dis(u_1,u_2)了

const int N=1e6+15;
const int inf=1e18;

int h[N],to[N],ne[N],w[N],idx=0;
int siz[N],son[N],fa[N],topfa[N],dep[N];
int in[N],out[N],dis[N],tim;
queue<int> rub;
vector<int> del[N];
struct pair_point
{
    pair<int,int> u,v;
}ansp;
struct segment_tree
{
    int ls,rs;
    pair_point sum;
}t[N<<3];
int rt[N];
int n,m,T,cnt;
int ans;

void init()
{
    while(!rub.empty())
        rub.pop();
    for(int i=1;i<=n;i++)
    {
        del[i].clear();
        son[i]=fa[i]=0;
        dis[i]=dep[i]=0;
        fa[i]=topfa[i]=0;
        in[i]=out[i]=0;
        h[i]=-1;
        rt[i]=0;
    }
    for(int i=1;i<=cnt;i++)
        t[i].ls=t[i].rs=0,t[i].sum={{0,0},{0,0}};
    tim=cnt=idx=0;
    ans=-inf;
}

void add(int u,int v,int val)
{
    to[++idx]=v;
    ne[idx]=h[u];
    h[u]=idx;
    w[idx]=val;
}

void get_fa(int u,int f)
{
    fa[u]=f;
    siz[u]=1;
    dep[u]=dep[f]+1;
    in[u]=++tim;
    for(int i=h[u];i!=-1;i=ne[i])
    {
        int v=to[i];
        if(v==f)
            continue;
        dis[v]=dis[u]+w[i];
        get_fa(v,u);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]]||!son[u])
            son[u]=v;
    }
    out[u]=tim;
}   

void get_topfa(int u,int topf)
{
    topfa[u]=topf;
    if(son[u])
        get_topfa(son[u],topf);
    for(int i=h[u];i!=-1;i=ne[i])
    {
        int v=to[i];
        if(v==son[u]||v==fa[u])
            continue;
        if(!topfa[v])
            get_topfa(v,v);
    }
}

pair<int,pair<int,int>> get_lca(int x,int y)
//返回<lca,<x->lca中lca的儿子 >,<y->lca 中lca的儿子>>
{
    bool fl=0;
    int resx=x;
    int resy=y;
    while(topfa[x]!=topfa[y])
    {
        if(dep[topfa[x]]<dep[topfa[y]])
            swap(x,y),swap(resx,resy),fl^=1;
        resx=topfa[x];
        x=fa[topfa[x]];
    }
    int lca;
    if(x==y)
        lca=x;
    else if(dep[x]<dep[y])
        lca=x,resy=son[x];
    else
        lca=y,resx=son[y];
    if(fl)
        swap(resx,resy);
    return {lca,{resx,resy}};
}

int new_node()
{
    if(!rub.empty())
    {
        int l=rub.front();
        rub.pop();
        return l;
    }
    else 
        return ++cnt;
}

bool check(pair<int,int> x,pair<int,int> y,int &maxx)
{
    if(x.first&&y.first)
    {
        int lca=get_lca(x.first,y.first).first;
        int res=dis[x.first]+dis[y.first]-2*dis[lca]+x.second+y.second;
        if(res>maxx)
        {
            maxx=res;
            return 1;
        }
        else 
            return 0;
    }
    return 0;
}

int merge_ans(pair_point &ans,pair_point x,pair_point y)
{
    if(!x.u.first&&!x.v.first)
    {
        ans=y;
        return -inf;
    }
    if(!y.u.first&&!y.v.first)
    {
        ans=x;
        return -inf;
    }
    int maxx=-inf;
    int res;
    int opt=0;
    if(check(x.u,y.u,maxx))
        opt=1;
    if(check(x.u,y.v,maxx))
        opt=2;
    if(check(x.v,y.u,maxx))
        opt=3;
    if(check(x.v,y.v,maxx))
        opt=4;
    res=maxx;
    if(check(x.u,x.v,maxx))
        opt=5;
    if(check(y.u,y.v,maxx))
        opt=6;
    if(opt==1)
        ans={x.u,y.u};
    if(opt==2)
        ans={x.u,y.v};
    if(opt==3)
        ans={x.v,y.u};
    if(opt==4)
        ans={x.v,y.v};
    if(opt==5)
        ans={x.u,x.v};
    if(opt==6)
        ans={y.u,y.v};
    return res;
}

void clear(int k)
{
    t[k].ls=t[k].rs=0;
    t[k].sum={{0,0},{0,0}};
    rub.push(k);
}

void push_up(int k)
{
    if(!lson||!rson)
        t[k].sum=t[lson+rson].sum;
    merge_ans(t[k].sum,t[lson].sum,t[rson].sum);
}

void update_add(int &k,int l,int r,int pos,pair<int,int> x)
{
    if(!k)
        k=new_node();
    if(l==r)
    {
        t[k].sum.u=x;
        return;
    }
    int mid=(l+r)>>1;
    if(pos<=mid)
        update_add(lson,l,mid,pos,x);
    else 
        update_add(rson,mid+1,r,pos,x);
    push_up(k);
}

void update_del(int &k,int l,int r,int pos)
{
    if(!k)
        k=new_node();
    if(l==r)
    {
        t[k].sum.u={0,0};
        return;
    }
    int mid=(l+r)>>1;
    if(pos<=mid)
        update_del(lson,l,mid,pos);
    else 
        update_del(rson,mid+1,r,pos);
    push_up(k);
}

int merge_tree(int x,int y)
{
    if(!x||!y)
        return x+y;
    merge_ans(t[x].sum,t[x].sum,t[y].sum);
    t[x].ls=merge_tree(t[x].ls,t[y].ls);
    t[x].rs=merge_tree(t[x].rs,t[y].rs);
    clear(y);
    return x;
}

void get_ans(int u)
{
    for(int i=h[u];i!=-1;i=ne[i])
    {
        int v=to[i];
        if(v==fa[u])
            continue;
        get_ans(v);
        ans=max(ans,merge_ans(ansp,t[rt[u]].sum,t[rt[v]].sum)-2*dis[u]);
        rt[u]=merge_tree(rt[u],rt[v]);
    }
    for(auto v:del[u])
        update_del(rt[u],1,m,v);
}

signed main()
{
    // freopen("center.in","r",stdin);
    // freopen("center.out","w",stdout);
    cin>>T;
    while(T--)
    {
        cin>>n;
        init();
        for(int i=1;i<n;i++)
        {
            int u,v,w;
            cin>>u>>v>>w;
            add(u,v,w);
            add(v,u,w);
        }
        get_fa(1,0);
        get_topfa(1,1);
        // int x,y;
        // cin>>x>>y;
        // cout<<get_dis(x,y);
        cin>>m;
        for(int i=1;i<=m;i++)
        {
            int u,v,cost;
            cin>>u>>v>>cost;
            pair<int,pair<int,int>> res=get_lca(u,v);
            int lca=res.first;
            int w=dis[u]+dis[v]-2*dis[lca]-2*cost;
            int lu=res.second.first;
            int lv=res.second.second;
            // cout<<"e le mi "<<lca<<" "<<lu<<" "<<lv<<endl;
            if(u!=lca)
            {
                pair<int,int> val={v,w+dis[u]};
                ans=max(ans,merge_ans(ansp,{val,{0,0}},t[rt[u]].sum)-2*dis[u]);
                update_add(rt[u],1,m,i,val);
                del[lu].push_back(i);
            }
            if(v!=lca)
            {
                pair<int,int> val={u,w+dis[v]};
                ans=max(ans,merge_ans(ansp,{val,{0,0}},t[rt[v]].sum)-2*dis[v]);
                update_add(rt[v],1,m,i,val);
                del[lv].push_back(i);
            }
        }
        get_ans(1);
        if(ans==-inf)
            cout<<"F"<<endl;
        else
            cout<<ans/2<<endl;
    }
    return 0;
}