P3356 火星探险问题

· · 题解

又回来写费用流了。

解法

众所周知,网络流的题你一般看不出来要用网络流,所以在此我解释一下这个题是如何想到用网络流的。(当然如果你是看了 tag 知道的我就没话说了)

看一下题目要求的目标:

  1. 最大化探测车的数量;
  2. 在 1 的前提下,最大化岩石的数量。

这个是不是长得有点像费用流?在保证最大流的前提下满足费用最大/小,好像是这么回事啊。

再看,这是一个网格图,看到网格图一般都会想到 dp 或者贪心,但是 dp 状态好像很难搞,贪心显然是错误的。

最后看数据范围,1\le n,p,q\le 35,正好是网络流算法所能够接受的。

于是能够确定:这是一道最大费用最大流的题。

好,既然这样,我们想一想怎么建图。

首先将题目中的每一条信息转换成网络流图的东西,可以得到:除了是障碍的格子,每一个格子都与它下面的和右边的格子有一条有向边,是石头的格子有点权,为 1(因为只有一个石头)。因为流图需要处理边权,于是我们把每一个格子 (i,j) 拆成两个点 (i,j)(i',j'),然后从 (i,j)(i',j') 连一条容量无限,花费 0 的有向边。对于不是障碍的格子,将 (i',j') 连向 (i+1,j)(i,j+1),容量无限,费用为 0,表示这些点之间可以随意的走。对于是石头的格子,我们在 (i,j)(i',j') 之间再连一条容量为 1,花费为 1 的有向边,表示这个石头只有一个,并且只能拿一次。最后建立超级源 S 和超级汇 T,分别在 S(1,1) 间、 (x_p',y_q')T 间连一条容量为 n,费用为 0 的有向边。这样就完成了建图。

然后是输出的问题。我们可以这样搞:从起点出发,每路过一个点 u 就走其反向弧还有流量的邻边,然后将这条反向弧的容量减 1,一直到到达终点就结束。对于这个邻边所连的点 v,如果在网格图中它在 u 的下方就说明要往南走,在右方就说明要往东走,然后输出就行了。

具体可以看代码理解。

Code

#include <bits/stdc++.h>
#define loop(i,a,b) for(int i=(a);~i;i=(b))
#define Mem(a,b) memset ((a),(b),sizeof((a)))
#define eb emplace_back
#define pb push_back
using namespace std;
typedef long long ll;
constexpr int N = 5000 + 15, inf = 1e9;

namespace FAST_IO {
//快读快写已为您省略
}using namespace FAST_IO;

int p, n, m, S, T;
int head[N], to[N << 1], nxt[N << 1], idx;
ll val[N << 1], c[N << 1];
int id[N][N], cnt;
ll h[N], dis[N];
int pre_e[N], pre_d[N];
bitset <N> vis;

void add (int u, int v, ll w1, ll w2) {
    to[idx] = v;
    val[idx] = w1, c[idx] = w2;
    nxt[idx] = head[u];
    head[u] = idx ++;
}

void SPFA () {
    queue <int> q;
    Mem (h, 0xcf);
    h[S] = 0;
    q.emplace (S);
    vis[S] = 1;
    while (!q.empty ()) {
        int u = q.front ();
        q.pop ();
        vis[u] = 0;
        loop (i, head[u], nxt[i]) {
            int v = to[i];
            if (val[i] && h[v] < h[u] + c[i]) {
                h[v] = h[u] + c[i];
                if (!vis[v]) {
                    vis[v] = 1;
                    q.emplace (v);
                }
            }
        }
    }
}

struct node {
    ll dis;
    int u;
    bool operator < (const node &x) const {
        return dis < x.dis;
    }
};

bool Dijkstra () {
    priority_queue <node> q;
    vis.reset ();
    Mem (dis, 0xcf);
//  print (dis[S], '\n');
    dis[S] = 0;
    q.push ({0, S});
    while (!q.empty ()) {
        int u = q.top().u;
        q.pop ();
        if (vis[u]) continue;
        vis[u] = 1;
        loop (i, head[u], nxt[i]) {
            int v = to[i];
            ll w = c[i] + h[u] - h[v];
            if (val[i] && dis[v] < dis[u] + w) {
                dis[v] = dis[u] + w;
                pre_e[v] = i;
                pre_d[v] = u;
                q.push ({dis[v], v});
            }
        }
    }
//  print (dis[T], '\n');
    return dis[T] > -3472328296227680305;
}

ll solve () {
    ll ans = 0;
    SPFA ();
    while (Dijkstra ()) {
//      cout << 1;
        ll mf = inf;
        for (int i = 0; i <= cnt * 2 + 3; ++ i) h[i] += dis[i];
        for (int i = T; i ^ S; i = pre_d[i]) mf = min (mf, val[pre_e[i]]);
        for (int i = T; i ^ S; i = pre_d[i]) {
            val[pre_e[i]] -= mf;
            val[pre_e[i] ^ 1] += mf;
        }
        ans += mf;
//      print (h[T], '\n');
    }
    return ans;
}//最大费用最大流

void getans (int p) {//输出方案
    int x = 1, y = 1;
    while (true) {
        if (x == n && y == m) return ;
        int u = id[x][y] + cnt;
        loop (i, head[u], nxt[i]) {
            int v = to[i];
            if (!val[i ^ 1]) continue;
            if (v == id[x + 1][y]) {
                print (p, ' ');
                print (0, '\n');
                ++ x;
                val[i ^ 1] --;
                break;
            }
            if (v == id[x][y + 1]) {
                print (p, ' ');
                print (1, '\n');
                ++ y;
                val[i ^ 1] --;
                break;
            }//如上所言
        }
    }
}

int main () {
    Mem (head, -1);
    read (p, m, n);
    for (int i = 1; i <= n; ++ i) for (int j = 1; j <= m; ++ j) id[i][j] = ++ cnt;
    S = 0, T = cnt * 2 + 2;
    for (int i = 1; i <= n; ++ i) {
        for (int j = 1; j <= m; ++ j) {
            int x;
            read (x);
            add (id[i][j], id[i][j] + cnt, inf, 0);
            add (id[i][j] + cnt, id[i][j], 0, 0);
            if (x != 1) {
                if (i + 1 <= n) {
                    add (id[i][j] + cnt, id[i + 1][j], inf, 0);
                    add (id[i + 1][j], id[i][j] + cnt, 0, 0);
                }

                if (j + 1 <= m) {
                    add (id[i][j] + cnt, id[i][j + 1], inf, 0);
                    add (id[i][j + 1], id[i][j] + cnt, 0, 0);
                }

                if (x == 2) {
                    add (id[i][j], id[i][j] + cnt, 1, 1);
                    add (id[i][j] + cnt, id[i][j], 0, -1);
                }
            }
        }
    }
    add (S, id[1][1], p, 0);
    add (id[1][1], S, 0, 0);
    add (id[n][m] + cnt, T, p, 0);
    add (T, id[n][m] + cnt, 0, 0);//建图
    ll ans = solve ();
    for (int i = 1; i <= ans; ++ i) getans (i);
    return 0;
}