P3398 题解

· · 个人记录

思路

设一条路径为 (a,b),另一条为 (c,d),两条路径如果相交,那么有 lca(a,b)(c,d) 上或 lca(c,d)(a,b)上,这个东西画一下图就能理解,但是我不知道怎么证明

可以这么理解:
首先已知对于任意两个点 x,ylca(x,y) 一定在 xy 的路径上。
先确定了一条路径 (a,b) 和路径外的一点 c,还有一个点 d 未知。要使 (a,b)(c,d) 相交,就要使 lca(c,d) 成为路径 (a,b) 上的一点。

那么如何判断一个点是否在某条路径上呢?
设判断 a 是否在 (c,d) 上,如果 dis_{c,a} + dis_{a,d} = dis_{c,d},那么 a 就在 (c,d)
另外显然,dis_{a,b} = dep_a + dep_b - 2 \cdot dep_{lca(a,b)}dep 是节点的深度)

代码

#include <bits/stdc++.h>
using namespace std;
namespace Main
{
    const int maxn=100005;
    int n,q;
    int head[maxn<<1];
    struct EDGE
    {
        int to,nxt;
    }edge[maxn<<1];
    int cnt;
    inline void add(int u,int to)
    {
        edge[++cnt].to=to;
        edge[cnt].nxt=head[u];
        head[u]=cnt;
    }
    int dep[maxn];
    int fa[18][maxn];
    void dfs(int u,int _fa)
    {
        dep[u]=dep[_fa]+1;
        fa[0][u]=_fa;
        for(int i=1;i<=17;i++)
        {
            fa[i][u]=fa[i-1][fa[i-1][u]];
        }
        for(int i=head[u];i;i=edge[i].nxt)
        {
            int to=edge[i].to;
            if(to==_fa)continue;
            dfs(to,u);
        }
    }
    inline int lca(int a,int b)
    {
        if(dep[a]<dep[b])swap(a,b);
        for(int i=17;i>=0;i--)
        {
            if(dep[fa[i][a]]>=dep[b])
            {
                a=fa[i][a];
            }
        }
        if(a==b)return a;
        for(int i=17;i>=0;i--)
        {
            if(fa[i][a]!=fa[i][b])
            {
                a=fa[i][a];
                b=fa[i][b];
            }
        }
        return fa[0][a];
    }
    inline int dis(int a,int b)
    {//计算 a,b 两点的距离 
        int _lca=lca(a,b);
        return dep[a]+dep[b]-2*dep[_lca];
    }
    inline bool check(int a,int b,int c)
    {//判断 c 是否在 a,b 的链上 
        if(dis(a,c)+dis(c,b)==dis(a,b))
        {
            return 1;
        }
        return 0;
    }
    void main()
    {
        scanf("%d%d",&n,&q);
        for(int i=1;i<n;i++)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            add(u,v);
            add(v,u);
        }
        dfs(1,0);
        for(int i=1;i<=q;i++)
        {
            int a,b,c,d;
            scanf("%d%d%d%d",&a,&b,&c,&d);
            int _lca_ab=lca(a,b);
            int _lca_cd=lca(c,d);
            if(check(a,b,_lca_cd)||check(c,d,_lca_ab))
            {
                printf("Y\n");
            }
            else printf("N\n");
        }
    }
}
int main()
{
    Main::main();
    return 0;
}