题解:P12401 [COI 2025] 玻利维亚 / Bolivija

· · 题解

Q = 0 时的解法(O(n + \mathrm{mx}) 暴力)

首先,如果一个 (A,B) 是好的,那么显然 (A,B) 之间的每一层都是对称的。(就题面中的图而言,从下往上的第 1,3,4,6,7 层是对称的,而第 2,5 层不是。)如果一对柱子高度不相等,那么它们就会让它们之间的高度不对称。

fail_i 为第 i 层有多少对柱子不对称,那么显然,对于柱子 j,它和与它相对的柱子 n-j+1 会让 i \in (\min(v_j,v_{n-j+1}),\max(v_j,v_{n-j+1})] 中的 fail_i 都增加 1

(对于样例 2fail 数组的值即为 0,1,0,0,1,0,0。)

那么,要统计的答案即为“ fail 中只含有数字 0 的子区间个数”。如果 fail 中有极长的一个区间 [l,r] 满足 \forall i \in [l,r], fail_i = 0 并且 fail_{l-1} \neq 0fail_{r+1} \neq 0[l,r] 之间的所有子区间就都满足要求。因此 [l,r] 对答案的贡献为 T(r-l+1)。(在本题解中设 T(x) = \sum_{i=1}^{i \le x} i = \frac{x(x+1)}{2})。

(在样例 2fail 数组中这样的区间分别为 [1,1],[3,4],[6,7],因此答案为 T(1-1+1) + T(4-3+1) + T(7-6+1) = 1+3+3 = 7

于是我们可以在本题中解决 Q = 0 时的情况,获得 32 分的高分。

代码:

#include <bits/stdc++.h>
#define _eggy_ using
#define _party_ namespace
_eggy_ _party_ std;

#define sjm 654210
int n,m,i;
int a[200100], c[sjm], fail[sjm]; // fail[i]表示第i层有多少个柱子不对称 
long long ans,len0;
#define T(x) (((x)*((x)+1))>>1)
int main()
{
    scanf("%d%d",&n,&m);
    for (i = 1; i <= n; i++) scanf("%d",a+i);
    for (i = 1; (i << 1) <= n; i++)
    {
        c[min(a[i],a[n-i+1])+1]++;
        c[max(a[i],a[n-i+1])+1]--;
    }for (i = 1; i <= a[(n+1)>>1]; i++)
    {
        fail[i] = fail[i-1] + c[i];

        if (!fail[i]) len0++;
        else
        {
            ans += T(len0);
            len0 = 0;
        }
    }ans += T(len0);
    printf("%lld",ans);
    return 0;
}

正解(O(n + Q \log \mathrm{mx})

题面翻译

v 的数字变化时,fail 数组也会变化。具体来说,我们可以把 (\min(v_x,v_{n-x+1}),\max(v_x,v_{n-x+1})] 减少 1,修改后再在新的区间中增加 1

于是我们完成了本题的中译中:

给出一个非负整数数列 fail,要求支持以下 2 种操作:

  1. 给出 l,r,v,将区间 [l,r] 内每个数增加 vv \in \{1,-1\}
  2. 查询整个数列内有多少个区间内部只有 0

现在可以开始做题了。

正解思路

操作 1 告诉我们要用线段树。考虑怎么合并询问的这个离谱信息。

容易想到,把左右儿子的合法区间数(设为 ans)相加。但还有一些区间跨越了中点。如果左儿子的最右端和右儿子的最左端都是 0 的话,就会出现这样的区间。因此我们还需要记录线段树上每个区间从左端点往右最长连续 0 的个数(设为 llen),和从右端点往左最长连续 0 的个数(设为 rlen)。

但是这样会产生一个严重的问题:如果一个区间内的数没有 0,但是有很多 1,那么上面的三个变量统统都是 0!这时如果对区间整体减少 1,那么求出正确的 ans 就不得不递归计算它的所有儿子,这样的时间复杂度将会是 O(\mathrm{mx})

再想想,对于每一个 k,维护区间减少 k 时的答案,但是这样更不可行。即使是用 map 只维护区间内出现的数,也要在每次合并时遍历整个 map

经过三节课的思考后……

给出一个非负整数数列……

我们只需要维护当区间的最小值被减成 0 时的答案即可!这是因为,当该区间最小值被减成 0 后该区间就不可能再整体减少,因为这会使最小值变成负数。

于是我们还要在线段树上维护区间最小值(设为 minn)。为了判断区间合并时的交界处和区间最小值是否一样,我们还需要维护区间最左边的数(设为 lis)和最右边的数(设为 ris)。相应地,llenrlen 的定义分别被改为从左到右/从右到左最长连续 lis/ris 的个数。

AC code

#include <bits/stdc++.h>
#define _eggy_ using
#define _party_ namespace
_eggy_ _party_ std;

int n,m,i,k;

#define sjm 654210 //(萨哈马火山的高度为 654200cm)
int a[200100], c[sjm], fail[sjm]; // fail[i]表示第i层有多少个位置不对称 
long long ans,len0;
#define T(x) (((x)*((x)+1ll))>>1) // x 可以为int 

#define cint const int&
template<unsigned char ce> //关于编译期确定线段树结构的方法 :查询 https://www.luogu.com.cn/article/u3r9m9nj 
struct dzpd
{
    struct dzpd <ce-1> ls,rs;
    int l,r,lis,llen,ris,rlen,added,minn;
    long long ans;
    void up()
    {
        lis = ls.lis, ris = rs.ris;
        llen = ls.llen, rlen = rs.rlen;
        if (ls.lis == rs.lis && ls.llen == ls.r - ls.l + 1) 
            llen += rs.llen;
        if (rs.ris == ls.ris && rs.rlen == rs.r - rs.l + 1) 
            rlen += ls.rlen;

        if (ls.minn < rs.minn)
        {
            minn = ls.minn;
            ans = ls.ans;
        }
        else if(ls.minn > rs.minn)
        {
            minn = rs.minn;
            ans = rs.ans;
        }else
        {
            minn = ls.minn;
            ans = ls.ans + rs.ans;
            if (ls.ris == minn && rs.lis == minn)
            {
                ans -= T(ls.rlen);
                ans -= T(rs.llen);
                ans += T(ls.rlen + rs.llen);
            }
        }
    }
    void down()
    {
        if (!added) return;
        ls.lis += added, ls.ris += added, 
        ls.minn += added, ls.added += added;
        rs.lis += added, rs.ris += added, 
        rs.minn += added, rs.added += added;
        added = 0;
    }
    void bd(cint bdl, cint bdr) //建树
    {
        l = bdl, r = bdr, added = 0;
        if (bdl == bdr)
        {
            lis = ris = minn = fail[bdl];
            llen = rlen = ans = 1;
            return;
        }
        ls.bd(bdl,(bdl+bdr)>>1);
        rs.bd(((bdl+bdr)>>1)+1,bdr); up();
    }
    void add(cint t1, cint t2, cint v) //v = ±1 
    {
        if (t1 <= l && r <= t2)
        {
            lis += v, ris += v, minn += v, added += v;
            return;
        }down();
        if (t1 <= ls.r) ls.add(t1,t2,v);
        if (t2 >= rs.l) rs.add(t1,t2,v); up();
    }
};
template<>
struct dzpd<0>
{
    int l,r,lis,llen,ris,rlen,added,minn;
    long long ans;
    void bd(cint bdl, cint bdr)
    {
        l = bdl, r = bdr, added = 0;
        lis = ris = minn = fail[bdl];
        llen = rlen = ans = 1;
    }
    void add(cint t1, cint t2, cint v)
    {lis += v, ris += v, minn += v, added += v;}
}; struct dzpd <20> t;

int main()
{
    scanf("%d%d",&n,&m);
    for (i = 1; i <= n; i++) scanf("%d",a+i);
    for (i = 1; (i << 1) <= n; i++)
    {
        c[min(a[i],a[n-i+1])+1]++;
        c[max(a[i],a[n-i+1])+1]--;
    }
    for (i = 1; i <= a[(n+1)>>1]; i++)
        fail[i] = fail[i-1] + c[i];
    t.bd(1,a[(n+1)>>1]);
    if (t.minn == 0) printf("%lld\n",t.ans);
    else puts("0");
    while(m--)
    {
        scanf("%d%d",&i,&k);
        if (a[i] != a[n-i+1]) t.add(min(a[i],a[n-i+1])+1,max(a[i],a[n-i+1]),-1); a[i] = k;
        if (a[i] != a[n-i+1]) t.add(min(a[i],a[n-i+1])+1,max(a[i],a[n-i+1]),+1);
        if (t.minn == 0) printf("%lld\n",t.ans);
        else puts("0");
    }
    return 0;
}