*P8817 题解

· · 个人记录

P8817 题解

先行吐槽:数据过水,一个重大的错误竟然得了 80 分。

题目大意

给定 n , m , k2 \sim n n - 1 个点的分数,和一张 n 个点 m 条边的无向图(无自环、重边)。你必须恰好选择四个不同于 1 且互不相同的点,并按一定顺序将其排列好,设四个点按顺序分别为 a , b , c , d,则要求 1 \rightarrow aa \rightarrow bb \rightarrow cc \rightarrow dd \rightarrow 1 这五段的距离均小于等于 k + 1(即除每段的起点和终点外,最多额外经过 k 个点)。要求最大化 a , b , c , d 四个点的分数之和。

题目思路

下面记 a , b , c , d 分别为第一、二、三、四个点,并且我们认为两点距离大于 k + 1 时这两点不可互相到达。

首先不能四个点全枚举,不然这题还有啥意思。而且显然贪心是错的,不然这题还有啥意思。

既然边长固定为 1,那直接 n 次 BFS 把全源最短路求出来,时间复杂度 O(n ^ 2),这里不说了。

观察数据范围可得,我们只能枚举两个点,对于剩下的两个点,我们需要用常数次枚举找到最优答案。

当然枚举两个点也足够了。我们发现,ad 两个点都需要可以到达 1,我们可以花 O(n) 的时间把 1 可以到达的点找到,同样也是这些点可以到达 1。而对于 b , c 两个点,我们必须 O(n ^ 2) 枚举这两个点,原因也很显然:贪心是错的。

但我们既然要枚举这两个点,就需要在常数时间内找到 a , d 两点。既然上文可以 O(n) 时间内求出一个点能否到达 1,那不妨对于每一组 (x , y),求出能否同时满足 x 1 互相可达且 x y 互相可达。我们已经求出最短路了,这部分也很简单。

那么我们都能处理出来每个点可到达的点中,可以到达 1 的所有点了,肯定不能把所有这些信息全都保留下来,不然最坏情况下和直接枚举是没有区别的。显然对于一个点,我们并不需要存下所有同时可以到达这个点和 1 的所有点,而是只需要分数前若干大的点。不妨假设我们现在在考虑 b 点可到达的 a 点,那么 a 点不能和 c , d 两点相同。如果分数前两大的 a 点分别是最终选择的 c , d 两点,我们就需要再记录下分数第三大的 a 点。而这个过程中不需要额外考虑 1 这个点,因为我们可以手动将它的分数设置为 -\infin,本题中取小于等于 - 4 \times {10} ^ {18} 的数字即可。

综上所述,我们需要在枚举每一组 (x , y) 是否符合“x 1 互相可达且 x y 互相可达”的基础上,动态维护可同时到达某个点和 1 这个点的分数前三大,而这个信息可以用 set 动态维护。

然后就做完了。注意最好不要维护分数第四大及以后的值,因为没用,而且可能过不了。

题目代码

没啥好说的,我觉得这题不难理解。

代码缺省源这种东西自己看着来吧。

又附:一开始我把分数放到第二关键字排序,怒提 80 分。

long long n , m , k;
long long sc[2505]; // score
vector < long long > v[2505];
long long dist[2505][2505];
void bfs(int st)
{
    vector < bool > vis(n + 1 , 0);
    queue < pair < long long , long long > > q;
    q.push(make_pair(st , 0));
    while(q.size())
    {
        pair < long long , long long > a = q.front();
        q.pop();
        if(vis[a.first])
        {
            continue;
        }
        vis[a.first] = 1;
        dist[st][a.first] = a.second;
        for(int i : v[a.first])
        {
            q.push(make_pair(i , a.second + 1));
        }
    }
}
set < pair < long long , long long > > max_get[2505]; // 每个点能到的点中,能回到 1 的点的最大点权 
signed main()
{
    read(n , m , k);
    sc[1] = -LONG_LONG_MAX; 
    for(int i = 2 ; i <= n ; i++)
    {
        read(sc[i]);
    }
    for(int i = 1 ; i <= m ; i++)
    {
        int x , y;
        read(x , y);
        v[x].push_back(y);
        v[y].push_back(x);
    }
    memset(dist , 0x3f , sizeof(dist));
    for(int i = 1 ; i <= n ; i++)
    {
        bfs(i);
    }
    for(int i = 1 ; i <= n ; i++)
    {
        for(int j = 1 ; j <= n ; j++)
        {
            if(dist[j][1] <= k + 1 && dist[i][j] <= k + 1)
            {
                max_get[i].insert(make_pair(sc[j] , j));
            }
            if(max_get[i].size() > 3)
            {
                max_get[i].erase(max_get[i].begin());
            }
        }
    }
    long long ans = 0;
    for(int i = 2 ; i <= n ; i++)
    {
        for(int j = 2 ; j <= n ; j++)
        {
            if(dist[i][j] > k + 1 || i == j)
            {
                continue;
            }
            for(pair < long long , long long > x : max_get[i])
            {
                for(pair < long long , long long > y : max_get[j])
                {
                    int p1 = x.second , p2 = y.second;
                    long long s1 = x.first , s2 = y.first;
                    if(i == p1 || j == p1 || i == p2 || j == p2 || p1 == p2 || p1 == 1 || p2 == 1)
                    {
                        continue;
                    }
                    ans = max(ans , s1 + s2 + sc[i] + sc[j]);
                }
            }
        }
    }
    printnl(ans);
    return 0;
};