浅谈差分约束和同余最短路

· · 算法·理论

差分约束

定义

定义 差分约束系统 为一组形如

\begin{cases} x_{c_1}-x_{c'_1} \leq y_1 \\x_{c_2}-x_{c'_2} \leq y_2 \\ \cdots\\ x_{c_m} - x_{c'_m}\leq y_m\end{cases}

n 元一次不等式,其中 x,y 为常数。

差分约束是用来求解该不等式的的一种图论建模算法。

过程

发现差分约束系统中的每个约束条件 x_i-x_j \le y_k 可以通过变形得到 x_i \le x_j+y_k。注意到该式与最短路松弛操作(dis_x=dis_y+w_{x,y})相似,于是我们将每个变量 x_i 看做图中的一个节点,对于每个约束 x_i-x_j \le y_k,从 x_jx_i 连一条边权为 y_k 的有向边。最后建立一个超级原点 0,其向所有 i 连一条边权为 0 的有向边,该操作可以方便代码书写。

最后求得的解为 \begin{cases} x_1=dis_1 \\x_2=dis_2 \\ \cdots\\ x_n=dis_n \end{cases},其中 dis_ii0 的最短路径。

注意到该差分约束系统可能无解,因此在图中判断是否存在负环即可。

该算法求解的是差分约束系统的一组特解,想要求出通解可以将所有 x_i 同时加上一个常数 d,由于是同时增加,所以所有 x_{c_i}-x_{c'_i} 不变,因此该解仍满足原不等式组。

一般来说,差分约束建模得出的图中有负权边,因此一般使用 SPFA 算法进行求解,时间复杂度为 O(n\times m)

证明

由于是进行最短路操作,因此对于任意 i,必然有 dis_i-dis_{j} \le w_jji 的所有邻居),如不满足,将 dis_i 设为 dis_{j}+w_j 必定更优,不符合最短路的性质。

代码实现

以洛谷 P5960 为例:

#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
int n,m;
int dis[5005],cnt[5005];
bool vis[5005];
struct node{
    int x,w;
};
vector<node> v[5005];
bool spfa(int s){
    memset(dis,0x3f,sizeof(dis));
    queue<int> que;
    que.push(s);
    dis[s] = 0;
    vis[s] = 1;
    while(!que.empty()){
        int x = que.front();
        que.pop();
        vis[x] = 0;
        for(auto y:v[x]){
            if(dis[x]+y.w<dis[y.x]){
                dis[y.x] = dis[x]+y.w;
                if(!vis[y.x]){
                    que.push(y.x);
                    vis[y.x] = 1;
                    if(++cnt[y.x]>=n){
                        return 1;
                    }
                }
            }
        }
    }
    return 0;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin >> n >> m;
    for(int i=1;i<=m;i++){
        int x,y,w;
        cin >> x >> y >> w;
        v[y].push_back((node){x,w});
    }
    n++;
    for(int i=1;i<n;i++){
        v[n].push_back((node){i,0});
    }
    if(spfa(n)){
        cout << "NO";
    }
    else{
        for(int i=1;i<n;i++){
            cout << dis[i] << " ";
        }
    }
    return 0;
}

同余最短路

定义

同余最短路算法是用于求解形如“给出 n 个数 a_{1,2,\cdots,n},求有多少个非负整数 b 满足 b \in [l,r] 能使得 \displaystyle \sum_{i=1}^n a_ix_i=b 存在非负整数解”一类的问题。

过程

首先,我们可以分别求出 b \in [0,r] 的答案和 b \in [0,l-1] 的答案,做差即可得到答案。

现在考虑如何求解 b \in [0,h] 之间的答案。

考虑将原式变形。在 a_{1,2,\cdots,n} 随意选择一个模数,如 a_1。此时可将原式变形为 \displaystyle \sum_{i=1}^n a_ix_i=b \bmod a_1+k\times a_1,进而得到 \displaystyle a_1(x_1-k)+\sum_{i=2}^n a_ix_i=b \bmod a_1。因此对于任意的 b 为原方程的一个解,必然有 b'=b+k \times a_i 为原方程的解。

考虑使用动态规划解决该问题。我们用 dis_i 代表 b \bmod a_1=ib 的最小值。

因此我们有如下转移方程:

f_{(u+a_v) \bmod a_1}=f_u+a_v

相当于在 f_u 的方程(\displaystyle \sum_{i=1}^n a_ix_i=u)基础上将 x_v1 就得到了 f_{(u+a_j) \bmod a_1} 的方程(\displaystyle \sum_{i=1}^{v-1} a_ix_i+a_v(x_v+1)+\sum_{i=v+1}^{n}=u+a_j)。

发现这相当于在图中建 u \stackrel{a_v}{\longrightarrow} (u+a_v) \bmod a_1 这条边。建完边后在图上做从 0 开始的最短路 dis_i 即可求出 f_i=dis_i

由于对于任意的 b 为原方程的一个解,必然有 b'=b+k \times a_i 为原方程的解,因此 b \in [0,h] 的答案即为 \displaystyle \sum_{i=0}^{a_1-1} (\left \lfloor \frac{h - ans_i}{a_1} \right \rfloor+1)

实际代码实现时并不需要把图建出来。代码中要特殊处理 a_i=0 的情况,直接去除 a_i=0 的项即可(因为其对答案没有影响)。

由于图中有 a_1 个点,n \times a_1 条边,因此用 Dijkstra 求解最短路时间复杂度为 O(n \times a_1 \log a_1),用 SPFA 求解最短路时间复杂度为 O(a_1^2 \times n),但是由于图是建模出来的所以实际上达不到该时间复杂度,实际测试中比 Dijkstra 快。注意到时间复杂度与 a_1 有关,且 a 的顺序不影响答案,因此将 a_i 升序排序可以加速该算法。

代码

以洛谷 P2371 为例。

SPFA 写法

#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
int n,l,r;
int a[15],dis[500005];
bool vis[500005];
int calc(int h){
    memset(dis,0x3f,sizeof(dis));
    memset(vis,0,sizeof(vis));
    queue<int> que;
    dis[0] = 0;
    vis[0] = 1;
    que.push(0);
    while(!que.empty()){
        int x = que.front();
        que.pop();
        vis[x] = 0;
        for(int i=2;i<=n;i++){
            int y = (x+a[i])%a[1];
            if(dis[x]+a[i]<dis[y]){
                dis[y] = dis[x]+a[i];
                if(!vis[y]){
                    vis[y] = 1;
                    que.push(y);
                }
            }
        }
    }
    int res = 0;
    for(int i=0;i<a[1];i++){
        if(h>=dis[i]){
            res+=(h-dis[i])/a[1]+1;
        }
    }
    return res;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin >> n >> l >> r;
    int m = 0;
    for(int i=1;i<=n;i++){
        int x;
        cin >> x;
        if(x){
            a[++m] = x;
        }
    }
    n = m;
    if(!n){
        cout << 0;
        return 0;
    }
    sort(a+1,a+n+1);
    cout << calc(r)-calc(l-1);
    return 0;
}

Dijkstra 写法

#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
int n,l,r;
int a[15],dis[500005];
bool vis[500005];
struct node{
    int x,w;
    bool operator < (const node &b) const{
        return b.w<w;
    }
};
int calc(int h){
    memset(dis,0x3f,sizeof(dis));
    memset(vis,0,sizeof(vis));
    priority_queue<node> que;
    dis[0] = 0;
    que.push((node){0,0});
    while(!que.empty()){
        node x = que.top();
        que.pop();
        if(vis[x.x]){
            continue;
        }
        vis[x.x] = 1;
        for(int i=2;i<=n;i++){
            int y = (x.x+a[i])%a[1];
            if(dis[x.x]+a[i]<dis[y]){
                dis[y] = dis[x.x]+a[i];
                que.push((node){y,dis[y]});
            }
        }
    }
    int res = 0;
    for(int i=0;i<a[1];i++){
        if(h>=dis[i]){
            res+=(h-dis[i])/a[1]+1;
        }
    }
    return res;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin >> n >> l >> r;
    int m = 0;
    for(int i=1;i<=n;i++){
        int x;
        cin >> x;
        if(x){
            a[++m] = x;
        }
    }
    n = m;
    if(!n){
        cout << 0;
        return 0;
    }
    sort(a+1,a+n+1);
    cout << calc(r)-calc(l-1);
    return 0;
}