P3627 [APIO2009] 抢掠计划

· · 个人记录

题目

思路

首先思路很好想,先缩点,然后遍历一遍。

但是现在要讲讲一些拓扑缩点的注意事项。

  1. 对于缩点,去重边,自环,是必要的。
  2. 当你在遍历 DAG 图时,不要想当然,dfs 打标记是错的,不打标记是你得仔细思考是不是 O(n)

code

#include<bits/stdc++.h>
using namespace std;
const int N=5e5+110;
int read(){
    int x=0,f=1;char c=getchar();
    while(c>'9' || c<'0'){if(c=='-')f=-1;c=getchar();}
    while(c>='0' && c<='9'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return x*f;
}
vector<int>G[N];
int n,m,bk[N];long long ans,va[N],sva[N];
int head[N],to[N],last[N],tot;
int dfn[N],cnt,low[N];
int s[N],top;
int spct,col[N],ending[N],P,S,send[N];
void add(int u,int v){to[++tot]=v,last[tot]=head[u],head[u]=tot;return;}
void tarjan(int u){
    dfn[u]=low[u]=++cnt;bk[u]=1;s[++top]=u;
    for(int i=head[u];i;i=last[i]){
        int v=to[i];
        if(!dfn[v]){tarjan(v);low[u]=min(low[u],low[v]);}
        else if(bk[v])low[u]=min(low[u],low[v]);
    }
    if(dfn[u]==low[u]){
        spct++;
        while(s[top]!=u){
            col[s[top]]=spct;
            sva[spct]+=va[s[top]];
            if(ending[s[top]])send[spct]=1;
            bk[s[top]]=0;top--;
        }
        col[u]=spct;
        sva[spct]+=va[u];
        if(ending[u])send[spct]=1;
        bk[u]=0;top--;//我觉得这么写很清楚吧
    }
    return;
}
long long dis[N];
int ind[N];
map<pair<int,int>,int>mp; 
signed main(){
    n=read(),m=read();int t1,t2;
    for(int i=1;i<=m;i++){t1=read(),t2=read();add(t1,t2);}
    for(int i=1;i<=n;i++)va[i]=read();
    S=read(),P=read();
    for(int i=1;i<=P;i++)ending[read()]=1;
    for(int i=1;i<=n;i++)if(!dfn[i])tarjan(i);
    for(int u=1;u<=n;u++)
        for(int i=head[u];i;i=last[i]){
            int v=to[i];if(col[u]!=col[v] && mp[make_pair(col[u],col[v])]==0)G[col[u]].push_back(col[v]),mp[make_pair(col[u],col[v])]=1,ind[col[v]]++;
            //用map去重边(不去也行,去了时间上还是其他上都好一些吧)
        }
      //使用链式前向星建原图的边,用vector建缩点后的边。
    queue<int>q;
    for(int i=1;i<=spct;i++)if(!ind[i])q.push(i);
    memset(dis,128,sizeof(dis));
    dis[col[S]]=sva[col[S]];
    while(!q.empty()){
        int u=q.front();q.pop();
        for(int i=0;i<G[u].size();i++){
            int v=G[u][i];
            dis[v]=max(dis[v],dis[u]+sva[v]);
            if(--ind[v]==0)q.push(v);
        }
    }
    for(int i=1;i<=spct;i++)if(send[i])ans=max(ans,dis[i]);
    printf("%lld",ans);
    return 0; 
}