SOS DP
SOS DP(S um over S ubsets dynamic programming)
SOS DP
3^n枚举子集
SOS DP是解决类似
这样的问题
对于这个问题,有不同的解法
-
O(4^n)暴力
for(int mask=0;mask<(1<<N);mask++){ for(int i=0;i<(1<<N);i++){ if((mask&i)==i) F[mask]+=A[i]; } } -
O(3^n)暴力
for(int mask=0;mask<(1<<N);mask++){ F[mask]=A[0]; for(i=mask;i>0;i=(i-1)&mask){ F[mask]+=A[i]; } }对于这种解法,我们只枚举了mask的子集,若长度为k,则会枚举2^k次,因此时间复杂度为
\sum_{k=0}^{mask}C_{n}^{k} 2^k= \sum_{k=0}^{mask}C_{n}^{k} (2^k*1^{n-k})=(1+2)^n=3^n -
SOS DP 我们记f[mask][i]为对于数mask的二进制中,第i位是不同的
若
若
如图: )
代码如下
for(int mask=0;mask<(1<<N);mask++){
dp[mask][-1]=A[mask];
for(int i=0;i<N;i++){
if(mask&(1<<i))
dp[mask][i]=dp[mask][i-1]+dp[mask^(1<<i)][i-1];
else dp[mask][i]=dp[mask][i-1];
}
F[mask]=dp[mask][N-1];
}
还可以继续优化
for(int i=0;i<(1<<N);i++)
F[i]=A[i];
for(int i=0;i<N;i++){
for(int mask=0;mask<(1<<N);mask++){
if(mask&(1<<i))
F[mask]+=F[mask^(1<<i)];
}
}
注意:上述的N为数的二进制长度,即
例一 CF165E Compatible Numbers
思路:
对于任意一个数i,我们在数组中找出一个数,使a[i]是i的子集,且a[i]最大
对于数a[i],我们将其取反,记为i',输出f[i']
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn=(1<<22);
int n,a[maxn+5],f[maxn+5];
int main(){
cin>>n;
memset(f,-1,sizeof f);
for(int i=1;i<=n;i++){
cin>>a[i];
f[a[i]]=a[i];
}
for(int pos=1;pos<=22;pos++){
for(int i=0;i<(1<<22);i++){
if(i&(1<<(pos-1)))
f[i]=max(f[i],f[i^(1<<(pos-1))]);
}
}
for(int num=1;num<=n;num++){
int now=((1<<22)-1)^a[num];
cout<<f[now]<<" ";
}
return 0;
}
例二 CF449D Jzzhu and Numbers
题目要求与起来为0的方案数
考虑计算补集,ans=总方案-或起来不为0的方案数
记f[i]为或起来为i的方案数
分类讨论,有一位或起来不为0,有两位或起来不为0.....
容斥原理处理即可
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+5,p=1e9+7;
typedef long long ll;
int n,a[maxn],f[(int)(1<<20)+5];
int check(int num){
int ans=0;
while(num>0){
ans++;
num=num&(num-1);
}
return ans;
}
ll poww(ll a,ll cnt){
ll ans=1;
while(cnt>0){
if(cnt&1)ans=ans*a%p;
a=a*a%p;
cnt=(cnt>>1);
}
return ans;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
f[a[i]]++;
}
for(int i=20;i>=1;i--){
for(int num=0;num<(1<<20);num++){
if(num&(1<<(i-1)))
f[num^(1<<(i-1))]+=f[num],f[num^(1<<(i-1))]%=p;
}
}
ll sum=poww(2,n)%p;
for(int i=1;i<(1<<20);i++){
ll num=check(i);
if(num%2==0)sum=sum+poww(2,f[i]),sum%=p;
else sum=sum-poww(2,f[i])+p,sum%=p;
}
cout<<sum;
return 0;
}