SOS DP

· · 个人记录

SOS DP(S um over S ubsets dynamic programming)

SOS DP

3^n枚举子集

SOS DP是解决类似

F[mask]=\sum_{i\subseteq mask}^{}A[i]

这样的问题

对于这个问题,有不同的解法

  1. O(4^n)暴力

    for(int mask=0;mask<(1<<N);mask++){
    for(int i=0;i<(1<<N);i++){
        if((mask&i)==i)
            F[mask]+=A[i];
    }
    }
  2. O(3^n)暴力

    for(int mask=0;mask<(1<<N);mask++){
    F[mask]=A[0];
    for(i=mask;i>0;i=(i-1)&mask){
        F[mask]+=A[i];
    }
    }

    对于这种解法,我们只枚举了mask的子集,若长度为k,则会枚举2^k次,因此时间复杂度为

    \sum_{k=0}^{mask}C_{n}^{k} 2^k= \sum_{k=0}^{mask}C_{n}^{k} (2^k*1^{n-k})=(1+2)^n=3^n
  3. SOS DP 我们记f[mask][i]为对于数mask的二进制中,第i位是不同的

(mask)_2的第i位为0,则

f[mask][i]=f[mask][i-1]

(mask)_2的第i位为1,则

f[mask][i]=f[mask][i-1]+f[mask\oplus 2^i][i-1]

如图: )

代码如下

for(int mask=0;mask<(1<<N);mask++){
    dp[mask][-1]=A[mask];
    for(int i=0;i<N;i++){
        if(mask&(1<<i))
            dp[mask][i]=dp[mask][i-1]+dp[mask^(1<<i)][i-1];
        else dp[mask][i]=dp[mask][i-1];
    }
    F[mask]=dp[mask][N-1];
}

还可以继续优化

for(int i=0;i<(1<<N);i++)
    F[i]=A[i];
for(int i=0;i<N;i++){
    for(int mask=0;mask<(1<<N);mask++){
        if(mask&(1<<i))
            F[mask]+=F[mask^(1<<i)];
    }
}

注意:上述的N为数的二进制长度,即N=logn,时间复杂度为O(N*2^N)O(nlogn)

例一 CF165E Compatible Numbers

思路:

对于任意一个数i,我们在数组中找出一个数,使a[i]是i的子集,且a[i]最大

对于数a[i],我们将其取反,记为i',输出f[i']

代码:

#include<bits/stdc++.h>
using namespace std;
const int maxn=(1<<22);
int n,a[maxn+5],f[maxn+5];
int main(){
    cin>>n;
    memset(f,-1,sizeof f);
    for(int i=1;i<=n;i++){
        cin>>a[i];
        f[a[i]]=a[i];
    }
    for(int pos=1;pos<=22;pos++){
        for(int i=0;i<(1<<22);i++){
            if(i&(1<<(pos-1)))
                f[i]=max(f[i],f[i^(1<<(pos-1))]);
        }
    }
    for(int num=1;num<=n;num++){
        int now=((1<<22)-1)^a[num];
        cout<<f[now]<<" ";
    }

    return 0;
}

例二 CF449D Jzzhu and Numbers

题目要求与起来为0的方案数

考虑计算补集,ans=总方案-或起来不为0的方案数

记f[i]为或起来为i的方案数

分类讨论,有一位或起来不为0,有两位或起来不为0.....

容斥原理处理即可

代码:

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+5,p=1e9+7;
typedef long long ll;
int n,a[maxn],f[(int)(1<<20)+5];
int check(int num){
    int ans=0;
    while(num>0){
        ans++;
        num=num&(num-1);
    }
    return ans;
}
ll poww(ll a,ll cnt){
    ll ans=1;
    while(cnt>0){
        if(cnt&1)ans=ans*a%p;
        a=a*a%p;
        cnt=(cnt>>1);
    }
    return ans;
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        f[a[i]]++;
    }
    for(int i=20;i>=1;i--){
        for(int num=0;num<(1<<20);num++){
            if(num&(1<<(i-1)))
                f[num^(1<<(i-1))]+=f[num],f[num^(1<<(i-1))]%=p;
        }
    }
    ll sum=poww(2,n)%p;
    for(int i=1;i<(1<<20);i++){
        ll num=check(i);
        if(num%2==0)sum=sum+poww(2,f[i]),sum%=p;
        else sum=sum-poww(2,f[i])+p,sum%=p; 
    }
    cout<<sum;
    return 0;
}