题解:P4551 最长异或路径

· · 题解

题目链接

https://www.luogu.com.cn/problem/P4551

分析

首先我们有如下结论:

设两个节点到根节点的路径异或值为 x_1,x_2,则这两个节点之间的路径异或值为 x_1 \operatorname{xor} x_2

因此可以先求出每个节点到根节点的路径异或值,那么问题就转化成了:从 n 个整数中选两个进行异或运算,得到的结果最大是多少。

考虑使用字典树。可以将每个节点到根节点的路径异或值的二进制形式放进字典树里,答案即为字典树中与每个异或值取反后的结果匹配程度最高的值中的最大值。

细节内容见代码注释。

Code

#include<bits/stdc++.h>
#define i64 long long
#define max(x,y) ((x)>(y)?(x):(y))

using namespace std;

constexpr int N=1e5+5;
int n;

namespace Trie{
    int tot_trie;
    struct Node_trie{int son[2];}trie[int(3.1e6+5)];
    void insert_trie(int x){
        int p=0;
        for(int i=30;i>=0;--i){
            int num=x>>i&1;//二进制下从右往左数第 i+1 位的值 
            if(trie[p].son[num]==0)trie[p].son[num]=++tot_trie;
            p=trie[p].son[num];
        }
    }
    int query_trie(int x1){
        //按位取反(~ 运算符会把符号位也取反) 
        int x2=x1;
        for(int i=0;i<=30;++i)x2^=(1<<i);

        int p=0,res=0;
        for(int i=30;i>=0;--i){
            int num=x2>>i&1;
            //没有匹配的,只能走另一条路 
            if(trie[p].son[num]==0)
                p=trie[p].son[num^1];
            //匹配上了 
            else
                p=trie[p].son[num],
                res+=(1<<i);
        }
        return res;
    }
}

using namespace Trie;

namespace Graph{
    int tot_edge,hd[N];
    struct Edge{int to,val,lst;}g[N*2];
    void add_edge(int u,int v,int w){
        g[++tot_edge]=Edge{v,w,hd[u]};
        hd[u]=tot_edge;
    }

    //求每个节点到根节点的路径异或值 
    int xor_sum[N];
    void dfs(int x,int fa){
        //放进字典树里 
        insert_trie(xor_sum[x]);

        for(int i=hd[x];~i;i=g[i].lst)
            if(g[i].to!=fa){
                xor_sum[g[i].to]=xor_sum[x]^g[i].val;
                dfs(g[i].to,x);
            }
        return;
    }
}

using namespace Graph;

int main(){
    ios::sync_with_stdio(false),
    cin.tie(nullptr),cout.tie(nullptr);

    memset(hd,-1,sizeof hd);

    cin>>n;
    for(int i=1;i<n;++i){
        int x1,x2,x3;
        cin>>x1>>x2>>x3;

        add_edge(x1,x2,x3),add_edge(x2,x1,x3);
    }

    dfs(1,0);

    int ans=0;
    for(int i=1;i<=n;++i)
        ans=max(ans,query_trie(xor_sum[i]));

    cout<<ans<<endl;
    return 0;
}