P3096

· · 个人记录

题意

给定一个有向图,再给定 k 个重要点,q 次询问,每次询问包含 u,v,询问从 u 点到 v 点必须经过重要点的最短路径,请输出 q 次询问的最短路径长度之和。

sol

显然,一开始的想法是dij,发现2e4*2e4开不下,不难发现 k \leq 200 于是我们可以利用离散化的思想,把 k 个重要点编号就做完了。

代码

void dij1(int s,int st){
    priority_queue<node> q;
    for(int i=0;i<=N-10;i++){
        dis1[i][st]=Max;
    }
    dis1[s][st]=0;
    q.push({s,0});
    while(!q.empty()){
        node t=q.top();
        q.pop();
        int u=t.p;
        if(vis[u]) continue;
        for(int i=0;i<ve[u].size();i++){
            int v=ve[u][i].fi,w=ve[u][i].se;
            if(dis1[v][st]>dis1[u][st]+w){
                dis1[v][st]=dis1[u][st]+w;
                q.push({v,dis1[v][st]});
            }
        }
    }
}
void dij2(int s,int st){
    priority_queue<node> q;
    for(int i=0;i<=N-10;i++){
        dis2[i][st]=Max;
    }
    dis2[s][st]=0;
    q.push({s,0});
    while(!q.empty()){
        node t=q.top();
        q.pop();
        int u=t.p;
        if(vis[u]) continue;
        for(int i=0;i<ve[u].size();i++){
            int v=ve[u][i].fi,w=ve[u][i].se;
            if(dis2[v][st]>dis2[u][st]+w){
                dis2[v][st]=dis2[u][st]+w;
                q.push({v,dis2[v][st]});
            }
        }
    }
}
signed main(){

//   freopen("a.in","r",stdin);
//   freopen("a.out","w",stdout);
    cin>>n>>m>>k>>q;
    for(int i=1;i<=m;i++){
        cin>>u[i]>>v[i]>>w[i];
    }
    for(int i=1;i<=k;i++){
        cin>>a[i];
        mp[a[i]]=i;
    }
    for(int i=1;i<=m;i++){
        ve[u[i]].push_back({v[i],w[i]});
    }
    for(int i=1;i<=k;i++){
        dij1(a[i],mp[a[i]]);
    }
    for(int i=1;i<=n;i++){
        ve[i].clear();
    }
    for(int i=1;i<=m;i++){
        ve[v[i]].push_back({u[i],w[i]});
    }
    memset(vis,0,sizeof vis);
    for(int i=1;i<=k;i++){
        dij2(a[i],mp[a[i]]);
    }
    int sum=0,cnt=0;
    while(q--){
        int l,r;
        cin>>l>>r;
        int minn=Max;
        if(mp[l]){
            minn=min(minn,dis1[r][mp[l]]);
        }
        if(mp[r]){
            minn=min(minn,dis2[l][mp[r]]);
        }
        if(minn==Max){

            for(int i=1;i<=k;i++){
                minn=min(minn,dis1[r][i]+dis2[l][i]);
            }

        }
        if(minn==Max) continue;
            cnt++;
            sum+=minn;
    }
    cout<<cnt<<endl;
    cout<<sum<<endl;
    return 0;
}