【学习笔记】WQS二分

· · 算法·理论

介绍

WQS二分一般用于处理一类带限制的题目,如恰好选 k 个元素的题目,但是它有使用的前提,那就是原问题具有凹凸性。

举个例子,我们现在有一个上凸的函数,根据定义有它的斜率单调递减。此时考虑二分斜率 k,然后找出被切点,设它为 (x,F(x)),它表示当选择 x 个元素时答案为 F(x)

那么,如何求被切点呢?我们先考虑这个函数的意义,x 表示选了几个元素,而我们的切线表达式为 y=kx+b,式子中的 kx 表示每个目标元素的代价所增加的值,那么我们可以给目标元素的代价减去 k,再使用一些方法得到当前的答案,注意还需要求出当前答案中目标元素的个数。

假设我们最优答案的斜率为 k,这条切线称其为最优切线,此时选 x 个目标元素,那么此时的答案(包含改变的代价)即为 kx+b,而去掉改变的代价 kx 就变为了 b,此为最终答案。不难发现其为最优切线在 x=0 时的函数值,即最优切线的纵截距。

有的时候我们会发现某个斜率 k 会切到多个点,此时我们需要根据具体的题目来解决。例如:原问题上凸,要求答案最小,即要求最优切线的纵截距最小,画图可以知道越往左的节点可能的纵截距越小。

例题

[国家集训队] Tree I

经典例题。

F(x) 为选择 x 条白色边时的答案,其中 x 为选择的白色边的个数,所以越往左选择的白色边越少,也就是说越往左白色边的边权越大(边权越大被选中的可能性越小)。左边的点对应切线的斜率较大(具体证明去看题解),又因为左边的点选择的白色边数量较少。所以,k 即为白色边边权增加的值(白色边边权越大,出现在最小生成树中的概率越小)。我们在遇到一个切线切到多个节点的情况时,由于需要答案最小,所以越大的 k 越好。

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int N=5e4+5,M=1e5+5;
struct edge{int u,v,w,col;}e[M];
bool cmp(edge x,edge y){return(x.w!=y.w?x.w<y.w:x.col<y.col);}
class DSU
{
    private:
        int fa[N];
        int find(int x)
        {
            if(fa[x]==x)return x;
            return fa[x]=find(fa[x]);
        }
    public:
        void init(){for(int i=1;i<N;i++)fa[i]=i;}
        void merge(int x,int y){fa[find(x)]=find(y);}
        bool same(int x,int y){return find(x)==find(y);}
}dsu;
int n,m;
pair<int,int>check(int k)
{
    for(int i=1;i<=m;i++)
        if(e[i].col==0)e[i].w+=k;
    sort(e+1,e+m+1,cmp);
    int res=0,cnt=0;
    dsu.init();
    for(int i=1;i<=m;i++)
    {
        int u=e[i].u+1,v=e[i].v+1,w=e[i].w,col=e[i].col;
        if(!dsu.same(u,v))res+=w,cnt+=(col==0?1:0),dsu.merge(u,v);
    }
    for(int i=1;i<=m;i++)
        if(e[i].col==0)e[i].w-=k;
    return make_pair(res,cnt);
}
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    int need;
    cin>>n>>m>>need;
    for(int i=1;i<=m;i++)cin>>e[i].u>>e[i].v>>e[i].w>>e[i].col;
    int l=-101,r=101;
    while(l+1<r)
    {
        int mid=l+r>>1;
        if(check(mid).second>=need)l=mid;
        else r=mid;
    }
    cout<<check(l).first-need*l;
    return 0;
}

最小度限制生成树

F(x)s 度数为 x 时的答案,也就是 x 越小连接 s 的边边权越大。可以证明左边的点对应切线的斜率较大,故二分的斜率 k 为连接 s 的边边权减少的值。若同时切到多个节点,选择最右边的点(即对应切线斜率最大的节点)。

注意本题需要判断无解的情况,若连接 s 的边边权均为正无穷的情况下 s 的度数依然大于 k(此为 s 度数最小的情况),或者 s 的度数小于 k,就无解。

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int N=5e4+5,M=5e5+5;
struct edge{int u,v,w;}e[M];
class DSU
{
    private:
        int fa[N];
        int find(int x)
        {
            if(fa[x]==x)return x;
            return fa[x]=find(fa[x]);
        }
    public:
        void init(){for(int i=1;i<N;i++)fa[i]=i;}
        void merge(int x,int y){fa[find(x)]=find(y);}
        bool same(int x,int y){return find(x)==find(y);}
}dsu;
int n,m,s;
bool cmp(edge x,edge y){return x.w<y.w||x.w==y.w&&(x.v==s||x.u==s);}
pair<int,int>check(int k)
{
    for(int i=1;i<=m;i++)
        if(e[i].u==s||e[i].v==s)e[i].w-=k;
    sort(e+1,e+m+1,cmp);
    int res=0,cnt=0;
    dsu.init();
    for(int i=1;i<=m;i++)
    {
        int u=e[i].u,v=e[i].v,w=e[i].w;
        if(!dsu.same(u,v))res+=w,cnt+=(u==s||v==s?1:0),dsu.merge(u,v);
    }
    for(int i=1;i<=m;i++)
        if(e[i].u==s||e[i].v==s)e[i].w+=k;
    return make_pair(res,cnt);
}
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    int need,cnt=0;
    cin>>n>>m>>s>>need;
    for(int i=1;i<=m;i++)
    {
        cin>>e[i].u>>e[i].v>>e[i].w;
        if(e[i].u==s||e[i].v==s)cnt++;
    }
    int l=-3e4-1,r=3e4+1;
    if(check(l).second>need||cnt<need)cout<<"Impossible";
    else
    {
        while(l+1<r)
        {
            int mid=l+r>>1;
            if(check(mid).second<=need)l=mid;
            else r=mid;
        }
        if(check(l).second<need)l++;
        cout<<check(l).first+need*l;
    }
    return 0;
}

忘情

化简题目的式子可以得到 (1+\sum_{i=1}^nx_i)^2,不难发现可以使用DP,但是还需要一维来枚举分的段数,时间复杂度为 O(n^2),考虑别的做法。

若我们不考虑分的段数,有DP转移方程如下:

dp_i=\min_{j=1}^idp_{j-1}+(1+sum_i-sum_{j-1})^2

其中 sum_i=\sum_{j=1}^ix_j

F(x) 为分 x 段时原问题最小的答案,可以证明其具有上凸性,故越往左分的段数就越少,不难发现给每次转移添加一个代价即可,此时代价越大取的段数就越少。设这个代价为 k,则现在的转移方程如下:

dp_{u}=\min_{j=1}^idp_{j-1}+(1+sum_i-sum_{j-1})^2+k

二分斜率 k,若切到多个节点,选择可能被更大的斜率切的节点即可。

接下来分析DP的部分。设 f(i)=1+sum_ig(i)=sum_{i-1},则原式变为

dp_{u}=\min_{j=1}^idp_{j-1}+(f(i)-g(j))^2+k\\ dp_{u}=\min_{j=1}^idp_{j-1}+f(i)^2+g(j)^2-2f(i)g(j)+k

假设当前的 i 可以由 j_1j_2 转移过来,满足 j_1<j_2,求什么情况下从 j_2 转移更优。

先列出转移方程:

dp_{j_1-1}+f(i)^2+g(j_1)^2-2f(i)g(j_1)+k>dp_{j_2-1}+f(i)^2+g(j_2)^2-2f(i)g(j_2)+k\\ dp_{j_1-1}+g(j_1)^2-2f(i)g(j_1)>dp_{j_2-1}+g(j_2)^2-2f(i)g(j_2)\\ dp_{j_1-1}+g(j_1)^2-dp_{j_2-1}-g(j_2)^2>2f(i)(g(j_1)-g(j_2))

不等式两边同时除以 g(j_1)-g(j_2),因为 g(j_1)-g(j_2)<0,故大于变小于。

\frac{dp_{j_1-1}+g(j_1)^2-dp_{j_2-1}-g(j_2)^2}{g(j_1)-g(j_2)}<2f(i)

y(i)=dp_{i-1}+g(i)^2x(i)=g(i),则原式转化为

\frac{y(j_1)-y(j_2)}{x(j_1)-x(j_2)}<2f(i)\\ \frac{y(j_2)-y(j_1)}{x(j_2)-x(j_1)}<2f(i)

也就是说当 j_1j_2 满足上述条件时,从 j_2 转移比从 j_1 转移更优,具体的