题解:P10206 [JOI 2024 Final] 建设工程 2 / Construction Project 2

· · 题解

闲话

这个题目有点板子吧,小技巧比较多。

正文

一眼看到这个题目,我就想到了特判,当从 ST 的最短路径都小于等于 K 的时候,显然就是任意两个点之间都可以建边,而且不会影响到最短路径,这部分的答案就是 \frac{n\times (n-1)}{2}

好了,特判完,就来想想正解。

先考虑一个拆分的小技巧:对于从 ST 的最短路,我们可以拆成三部分,分别是 dis(S,u)dis(u,v)dis(v,T),其中 dis(i,j) 表示从点 i 到点 j 的最短路径。

这个技巧有什么用呢?

很简单,我们可以考虑枚举 u 点,然后对于任意的 v 点,都要满足这个性质:

dis(S,u)+L+dis(v,T)\le K

显然,我们用 L 替代了 dis(u,v),这表示我们在 uv 点之间建了新边。

因为我们枚举了 u 点,所以我们需要快速地得到满足上述条件的 v 点的数量,刚好,我们实际上并不需要知道究竟是哪些点,只需要知道点的数量,所以我们将 dis(v,T) 单独拿出来排序,然后十分套路地使用二分查找就可以了,甚至你不需要手写,因为upper_bound就可以做到这个。

所以代码具体就是:

但是有问题啊!

想想,如果对于一对 (u,v) 同时满足:

dis(S,u)+L+dis(v,T)\le K dis(S,v)+L+dis(u,T)\le K

那根据上述的算法,会在枚举到 u 点时将 v 点计算入内,在枚举到 v 点时将 u 点计算入内,但是它们本应该只计算一次!

但是我没考虑到上面的情况,依然通过了这道题,说明数据过水上面的情况是不成立的,考虑简单证明一下。

采用反证法,假设存在一对 (u,v) 满足上面的情况,那么也满足:

dis(S,u)+dis(v,T)+dis(S,v)+dis(u,T)+2\times L \le 2\times K

但我们将前面四项两两匹配:

(dis(S,u)+dis(u,T))+(dis(S,v)+dis(v,T))+2\times L\le 2\times K

明显就是:

2\times dis(S,T)+2\times L\le 2\times K

把系数消去:

dis(S,T)+L\le K

则:

dis(S,T)\le K-L

明显这里已经被我们在一开始特判掉了,所以我们不满足算重的条件,也就不会算重。

代码

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define V vector
#define FOR(i,a,b) for(int i=(int)(a);i<=(int)(b);i++)
#define pb push_back
const int INF=1e18+10;
using P=array<int,2>;
struct node{
    int id,dis;
    friend bool operator <(const node &a,const node &b){
        return a.dis>b.dis;
    }
};
V<int>dij(V<V<P> >&e,int st,int n){
    V<int>dis(n+1,INF);V<bool>vis(n+1,false);
    priority_queue<node>q;q.push({st,0});dis[st]=0;
    while(!q.empty()){
        int u=q.top().id;q.pop();
        if(vis[u])continue;
        vis[u]=true;
        for(auto i:e[u]){
            int v=i[0],w=i[1];
            if(dis[v]>dis[u]+w){
                dis[v]=dis[u]+w;
                q.push({v,dis[v]});
            }
        }
    }
    return dis;
}
signed main(){
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    int n,m,s,t,l,k;
    cin>>n>>m>>s>>t>>l>>k;
    V<V<P> >e(n+1);
    FOR(i,1,m){
        int a,b,c;
        cin>>a>>b>>c;
        e[a].pb({b,c});
        e[b].pb({a,c});
    }
    V<int>dis1=dij(e,s,n),dis2=dij(e,t,n);
    if(dis1[t]<=k){
        cout<<(n-1)*n/2;
        return 0;
    }
    V<int>dis3=dis2;
    int ans=0;
    sort(dis3.begin()+1,dis3.end());
//  FOR(i,1,n) cout<<dis3[i]<<" ";
//  cout<<"\n";
    FOR(i,1,n){
        int sum=upper_bound(dis3.begin()+1,dis3.end(),k-l-dis1[i])-dis3.begin()-1;
        ans+=sum;
    }
    cout<<ans;
    return 0;
}