ABC418F 题解

· · 题解

题意

给定一个长度为 n 的整数数组 a,下标从 1 编号到 n

有一个 1 \times n 的网格图,每个格子要么放茶,要么放咖啡,有以下两个条件:

  1. 任意两个相邻格子,至少要有一个格子放茶。
  2. 对于任意 i,使得 a_i \neq -1,那么前 i 个格子强制恰好有 a_i 个格子放咖啡。

求有多少种合法的摆放方案。

分析性质

任意两个相邻格子,至少要有一个格子放茶。

这个条件等价于:任意两个放咖啡的格子不能相邻。

对于第二个条件,本质上,将一段前缀的格子划分成了若干个区间,每个区间的咖啡数量是确定的,剩下一段后缀的咖啡数量则不确定。

暴力做法

考虑左到右有 m 个格子,且咖啡数量确定为 k,那么方案数是多少。

这个问题可以转化为在 m 个连续位置里面选择 k 个两两不相邻的位置的方案数。

再次转化,发现这个问题等价于有 m-k 个物品,需要选择 k 个这些物品的间隙,分别插入一个新的物品,那么方案数就是 \binom{m-k+1}{k}

为了简化表示,令上述问题的答案为 g(m,k)

首先,仅考虑已经确定咖啡总数的那些区间。

假设第 i 个区间的长度为 l_i,咖啡数量为 x_i

那么设计出状态 f_{i,0/1} 表示前 i 个区间已经确定,且最后一个区间的最后一个位置是否为咖啡,那么总合法方案数是多少。

考虑状态转移,以下默认每种转移的结果都累加。

  1. 上一个区间的末尾没有咖啡,这次的区间末尾也没有咖啡。
  2. 上一个区间的末尾有咖啡,这次的区间末尾没有咖啡。
    此时分类讨论两种情况。
    • l_i \ge 2
    • l_i = 1
  3. 上一个区间的末尾没有咖啡,这次的区间末尾有咖啡。
    • l_i \ge 2
    • l_i = 1
  4. 上一个区间的末尾有咖啡,这次区间的末尾也有咖啡。
    • l_i \ge 3
    • l_i = 2
    • l_i = 1
      无解。

上述 DP 解决了由已经确定总咖啡数量的区间组成的那一段前缀。

但是还剩下一段后缀,没有咖啡数量限制,仅被咖啡不能相邻的条件约束。

如果设计出一种新的 DP 解决这一部分,会比较麻烦,考虑复用之前的 DP。

一种比较简单的方法,可以将剩下的那一段后缀划分为若干个单个元素组成的区间,沿用原来的 DP 状态。

对于状态转移,则枚举当前区间有一个咖啡还是有零个咖啡,分别按照之前的转移方式算出结果,加起来就是总方案数。

现在,成功构造出了一种单次 O(n) 的做法。

优化

每次修改,都是修改某位置的前缀和。

注意到,它实际上最多影响两个区间的总和。

另外注意到,原来的 DP 状态,完全可以写成矩阵的形式,通过矩阵乘法转移,这样就支持了线段树快速合并答案。

现在最大的问题,在于每次操作都可能增加一个区间,或删除一个区间,这样就没法直接维护每个区间的答案。

但这个也是可以简单解决的。

首先开一个 set 维护所有存在的区间。

对于每个区间,可以只在其中一个位置维护其整个区间的转移矩阵,区间其他位置全部填充为单位矩阵。

每次修改,都将可能影响的两个区间全部填充为单位矩阵,然后重新填就行了。

#include <bits/stdc++.h>
using namespace std;
#define FOR(i,a,b) for(auto i=(a);i<=(b);i++)
#define REP(i,a,b) for(auto i=(a);i>=(b);i--)
#define FORK(i,a,b,k) for(auto i=(a);i<=(b);i+=(k))
#define REPK(i,a,b,k) for(auto i=(a);i>=(b);i-=(k))
#define pb push_back
#define mkpr make_pair
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef vector<int> vi;
template<class T>
void ckmx(T& a,T b){
    a=max(a,b);
}
template<class T>
void ckmn(T& a,T b){
    a=min(a,b);
}
template<class T>
T gcd(T a,T b){
    return !b?a:gcd(b,a%b);
}
template<class T>
T lcm(T a,T b){
    return a/gcd(a,b)*b;
}
#define gc getchar()
#define eb emplace_back
#define pc putchar
#define ep empty()
#define fi first
#define se second
#define pln pc('\n');
#define islower(ch) (ch>='a'&&ch<='z')
#define isupper(ch) (ch>='A'&&ch<='Z')
#define isalpha(ch) (islower(ch)||isupper(ch))
template<class T>
void wrint(T x){
    if(x<0){
        x=-x;
        pc('-');
    }
    if(x>=10){
        wrint(x/10);
    }
    pc(x%10^48);
}
template<class T>
void wrintln(T x){
    wrint(x);
    pln
}
template<class T>
void read(T& x){
    x=0;
    int f=1;
    char ch=gc;
    while(!isdigit(ch)){
        if(ch=='-')f=-1;
        ch=gc;
    }
    while(isdigit(ch)){
        x=(x<<1)+(x<<3)+(ch^48);
        ch=gc;
    }
    x*=f;
}
void ioopti(){
    ios::sync_with_stdio(0);
    cin.tie(0);
}
const ll mod=998244353;
const int maxn=2e5+5;
int n,q;
struct Matrix{
    ll mat[2][2];
    void init(){
        memset(mat,0,sizeof mat);
    }
    void unit(){
        init();
        mat[0][0]=mat[1][1]=1;
    }
    Matrix operator*(const Matrix& b)const{
        Matrix res;
        res.init();
        FOR(i,0,1){
            FOR(k,0,1){
                FOR(j,0,1){
                    (res.mat[i][j]+=mat[i][k]*b.mat[k][j])%=mod;
                }
            }
        }
        return res;
    }
    Matrix operator+(const Matrix& b)const{
        Matrix res;
        FOR(i,0,1){
            FOR(j,0,1){
                (res.mat[i][j]=mat[i][j]+b.mat[i][j])%=mod;
            }
        }
        return res;
    }
};
struct SGT{
    int lef[maxn<<2],rig[maxn<<2];
    Matrix res[maxn<<2],cov[maxn<<2];
    bool tag[maxn<<2];// 本节点存在 lazytag
    void pushup(int pos){
        res[pos]=res[pos<<1]*res[pos<<1|1];
    }
    void upd(int pos,Matrix _cov,bool _tag){
        tag[pos]=_tag;
        cov[pos]=_cov;
        res[pos]=_cov;
        // 理论上应该是 qpow(_cov,r-l+1)
        // 但是,赋值要么是单位矩阵,要么是区间长度为 1
    }
    void pushdown(int pos){
        if(!tag[pos])return;
        upd(pos<<1,cov[pos],tag[pos]);
        upd(pos<<1|1,cov[pos],tag[pos]);
        tag[pos]=0;
    }
    void bd(int pos,int l,int r){
        lef[pos]=l,rig[pos]=r,res[pos].unit(),tag[pos]=0;
        if(l==r)return;
        int mid=(l+r>>1);
        bd(pos<<1,l,mid);
        bd(pos<<1|1,mid+1,r);
    }
    void chg(int pos,int l,int r,Matrix _cov){
        if(lef[pos]>=l&&rig[pos]<=r){
            upd(pos,_cov,1);
            return;
        }
    //  printf("pos %d\n",pos);
        pushdown(pos);
        int mid=(lef[pos]+rig[pos]>>1);
        if(mid>=l)chg(pos<<1,l,r,_cov);
        if(mid<r)chg(pos<<1|1,l,r,_cov);
        pushup(pos);
    }
}sgt;
ll qpow(ll a,ll b){
    ll ret=1;
    for(;b;a=a*a%mod,b>>=1){
        if(b&1)ret=ret*a%mod;
    }
    return ret;
}
ll inv(ll a){
    return qpow(a,mod-2);
}
ll fact[maxn*2],factinv[maxn*2];
ll C(int n,int m){
//  printf("C(%d,%d) = %lld\n",n,m,((n>=m)?fact[n]*factinv[m]%mod*factinv[n-m]%mod:0));
    return ((n>=m)?fact[n]*factinv[m]%mod*factinv[n-m]%mod:0);
}
ll F(int len,int x){
    if(len<0||x<0)return 0;
    return C(len-x+1,x);
}
Matrix calc(int len,int x){
    Matrix res;
    res.init();
    // 原来末尾没有,现在末尾没有
    {
        res.mat[0][0]=F(len-1,x);
    }
    // 原来末尾有,现在末尾没有
    {
        if(len>=2){
            res.mat[1][0]=F(len-2,x);
        }else{
            // 第一个排除位置,与最后一个排除位置重合,实际只关心唯一一个位置是否不填
            res.mat[1][0]=(x==0);
        }
    }
    // 原来末尾没有,现在末尾有
    {
        if(len>=2){
            res.mat[0][1]=F(len-2,x-1);
        }else{
            res.mat[0][1]=(x==1);   
        }
    }
    // 原来末尾有,现在末尾有
    {
        if(len>=3){
            res.mat[1][1]=F(len-3,x-1);
        }else{
            res.mat[1][1]=(len==2&&x==1);
        }

    }
    return res;
}
set<int> poslist; 
int pre[maxn];
Matrix ways_nolim[maxn];// 没有咖啡数量的限制,仅有咖啡不能相邻的限制
Matrix unitmat;
void print(){
    Matrix ans;
    ans.init();
    ans.mat[0][0]=1;
    ans=ans*sgt.res[1];
    auto ptr=prev(poslist.end());
    ans=ans*ways_nolim[n-*ptr];
    printf("%lld\n",((ans.mat[0][0]+ans.mat[0][1])%mod+mod)%mod);
}
void solve(int id_of_test){
    read(n);
    read(q);
    vector<pii> queries;
    FOR(i,1,q){
        int x,y;
        read(x);
        read(y);
        queries.eb(mkpr(x,y));
    }
    sgt.bd(1,1,n);
    poslist.insert(0);
//  print();
//  printf("ways_no_lim %lld %lld\n",ways_nolim[n].mat[0][0],ways_nolim[n].mat[0][1]);
//  return;
    for(auto [pos,_val]:queries){
        if(pre[pos]==_val){
            print();            
            continue;
        }
        if(pre[pos]==-1){
            // 增加新的右端点
            pre[pos]=_val;
            auto _rg=poslist.lower_bound(pos);
            auto _lf=_rg;
            _lf--;
            int lf=*_lf+1,rg;
            sgt.chg(1,lf,pos,unitmat);// 删除 [lf,pos]
            if(_rg!=poslist.end()){
                rg=*_rg;
                {
                    // [lf,rg] 是原来 pos 所在的完整区间
                    // 首先删除原来区间      
                    sgt.chg(1,lf,rg,unitmat);
                }
                if(pos!=rg){
                    // 加入 [pos+1,rg]
                    sgt.chg(1,pos+1,pos+1,calc(rg-(pos+1)+1,pre[rg]-pre[pos]));
                }
            }
            // 加入 [lf,pos]
            sgt.chg(1,lf,lf,calc(pos-lf+1,pre[pos]-pre[lf-1]));
            poslist.insert(pos);
        }else if(_val==-1){
            // 减少一个右端点
            pre[pos]=-1;
            auto _rg=poslist.find(pos);
            auto _lf=_rg;
            _lf--;
            _rg++;
            int lf=*_lf+1;
            int rg;
            sgt.chg(1,lf,pos,unitmat);
            if(_rg!=poslist.end()){// 存在下一个右端点,提供合并
                rg=*_rg;
                // 删除原来区间
                sgt.chg(1,lf,rg,unitmat);
                // 加入新的
                if(pre[rg]!=-1)sgt.chg(1,lf,lf,calc(rg-lf+1,pre[rg]-pre[lf-1]));
            }else{
                sgt.chg(1,lf,n,unitmat);// 否则,后面整段都没了。
            }
            poslist.erase(pos);
        }else{
            // 简单地修改值
            pre[pos]=_val;
            auto _rg=poslist.find(pos);
            auto _lf=_rg;
            _lf--;
            _rg++;
            int lf=*_lf+1;
            int rg;
            sgt.chg(1,lf,pos,unitmat);
            if(_rg!=poslist.end()){
                rg=*_rg;
                sgt.chg(1,lf,rg,unitmat);
                if(pre[rg]!=-1)sgt.chg(1,rg,rg,calc(rg-(pos+1)+1,pre[rg]-pre[pos]));
            }
            sgt.chg(1,lf,lf,calc(pos-lf+1,pre[pos]-pre[lf-1]));
        }
        print();
    }
}
int main()
{
    unitmat.unit();
    fact[0]=1;
    FOR(i,1,maxn*2-1)fact[i]=fact[i-1]*i%mod;
    factinv[maxn*2-1]=inv(fact[maxn*2-1]);
    REP(i,maxn*2-1,1)factinv[i-1]=factinv[i]*i%mod;
    ways_nolim[0].unit();
    auto sgway=calc(1,0)+calc(1,1);
//  return 0;
    FOR(i,1,maxn-1){
        ways_nolim[i]=ways_nolim[i-1]*sgway;
    }
    FOR(i,1,maxn-1){
        pre[i]=-1;
    }
    int T;
    T=1;
    FOR(_,1,T){
        solve(_);
    }
    return 0;
}
/*
1. 对题意的理解能否和样例对的上?
2. 每一步操作,能否和自己的想法对应上?
3. 每一步操作的正确性是否有保证?
4. 是否考虑到了所有的 case?特别是极限数据。
5. 变量的数据类型是否与其值域匹配?
6. 时间复杂度有保证吗?
7. 空间多少 MB?
*/