题解 P3959 [NOIP2017 提高组] 宝藏

· · 个人记录

题目描述

一共 n 个宝藏(即 n 个点),m 条待开发隧道,每条隧道给出连接的两个宝藏的编号 u,v 和长度 L。到达一个宝藏所在地后,你可以获得该地的宝藏。你可以任意选择从一个宝藏所在地出发,开发一条隧道的代价是 L\ast KL 为这条隧道的长度,K 为从起始点到这条隧道的起点所经过的点的数量。如果该隧道连接的两个宝藏都已经被挖掘,那么这条隧道不能再被开发,已开发过的隧道可以重复经过。求最小的代价,能获得所有宝藏。

对于 70\% 的数据:1\leq n\leq 80\leq m \leq 10^3L\leq 5\ast 10^3

对于 100\% 的数据:1\leq n \leq 120\leq m \leq 10^3L\leq 5\ast 10^5

思路分析

可以想到比较暴力的方法,直接枚举全排列来得到答案,在本题中,我们还需要知道当前点是从哪个点扩展过来的,对于 70\% 的数据,8!=40320,直接暴力搜索就可以通过。

void dfs(int tot){
    if(tot==n){
        ans=min(ans,sum);
        return ;
    }
    if(sum>=ans)
        return ;
    for(int i=1;i<=n;i++){
        if(!lev[i])
            continue;
        for(int j=1;j<=n;j++){
            if(lev[j] || i==j || mp[i][j]==inf)
                continue;
            lev[j]=lev[i]+1;
            sum+=mp[i][j]*lev[i];
            dfs(tot+1);
            lev[j]=0;
            sum-=mp[i][j]*lev[i];
        }
    }
}

signed main(){
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            mp[i][j]=inf;
        }
    }
    for(int i=1;i<=m;i++){
        u=read(),v=read(),w=read();
        mp[u][v]=min(mp[u][v],w);
        mp[v][u]=min(mp[v][u],w);
    }
    for(int i=1;i<=n;i++){
        lev[i]=1;
        sum=0;
        dfs(1);
        lev[i]=0;
    }
    printf("%lld\n",ans);
    return 0;
}

然后我们就获得了 70 分的好成绩。

继续对搜索进行优化,考虑最优性剪枝。

首先我们来看需要用到的几个数组。

int to[15][15];//编号为 i 的点连接的第 j 个点的编号为 to[i][j]
int d[15];//编号为 i 的点连接了 d[i] 个点
int tmp;//用于最优性剪枝
int ans=inf;//统计最终答案
int sum;//dfs 过程中计算代价
int id[15];//当前枚举的第 i 个点的编号为 id[i]
int cnt;//当前已经扩展到了多少个点

红框中的点为已遍历节点,id[i] 节点为当前扩展节点,与 id[i] 相连的所有点,都是到达这些点所需要走的最短的边,我们假设这些最短的边,都可以从 id[i] 出发(实际的图可能不长这样),计算出理论最小代价,然后用之前计算的实际代价加上理论最小代价,如果这样都会超过之前计算出的最小的 ans,那么就没有再计算下去的必要了,因为继续往下扩展没有扩展过的点,实际代价只会更大。

int p;
bool cmp(int a,int b){
    return mp[p][a]<mp[p][b];
}

void dfs(int num,int pos){//从第 num 个点开始扩展,第 num 个点从它连接的第 pos 个点开始扩展。
    if(cnt==n){
        ans=min(ans,sum);
        return ;
    }
    for(int i=num;i<=cnt;i++){
        if(sum+tmp*lev[id[i]]>=ans)
            return ;
        for(int j=pos;j<=d[id[i]];j++){
            if(!lev[ to[id[i]][j] ]){
                id[++cnt]=to[id[i]][j];
                lev[ id[cnt] ]=lev[ id[i] ]+1;
                tmp-=mp[ id[cnt] ][ to[id[cnt]][1] ];
                sum+=mp[ id[i] ][ id[cnt] ]*lev[ id[i] ];
                dfs(i,j+1);
                sum-=mp[ id[i] ][ id[cnt] ]*lev[ id[i] ];
                tmp+=mp[ id[cnt] ][ to[id[cnt]][1] ];
                lev[ id[cnt] ]=0;
                id[cnt--]=0;
            }
        }
        pos=1;
    }
}

signed main(){
    n=read(),m=read();
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            mp[i][j]=inf;
        }
    }
    for(int i=1;i<=m;i++){
        u=read(),v=read(),w=read();
        if(mp[u][v]==inf){
            to[u][++d[u]]=v;
            to[v][++d[v]]=u;
        }
        mp[u][v]=min(mp[u][v],w);
        mp[v][u]=min(mp[v][u],w);
    }
    for(int i=1;i<=n;i++){
        p=i;
        sort(to[i]+1,to[i]+1+d[i],cmp);
        tmp+=mp[i][to[i][1]];
    }
    for(int i=1;i<=n;i++){
        cnt=1,sum=0;
        lev[i]=1;
        id[1]=i;
        tmp-=mp[i][to[i][1]];
        dfs(1,1);
        lev[i]=0;
        tmp+=mp[i][to[i][1]];
    }   
    printf("%lld\n",ans);
    return 0;
}