题解【[六校NOIP2021#63]蘄壺螸壚黜】

· · 个人记录

题目链接

我们先来考虑一条链怎么做。定义 f(i,j) 表示到第 i 位,末尾填 j 的方案数,s(i) 为到第 i 位的总方案数。转移方程:f(i,j)=s(i-1)-f(i-1,j)

直接转移是 n^2 的,我们可以用线段树优化,存的是每一个 j 的方案数。转移时直接全局乘 -1 再加上 s(i-1)

我们再来考虑环。先把环断成链,然后只要减去末尾跟头相等的个数。我们可以把 a_i 最小的当成链头,这样末尾肯定可以取到跟头相同的颜色。

下面是 AC 代码。

#include <bits/stdc++.h>
using namespace std;

#define int long long
#define lson o<<1
#define rson o<<1|1
#define mid (l+r>>1)

const int M = 998244353;

int n,m,ans=1,a[200005],b[200005],d[200005];
int ee,h[200005],nex[200005<<1],to[200005<<1];
int dsu[200005],c[200005];
int root,vis[200005];
struct segtree{
    int v,laz1,laz2;
}t[200005<<2];

int ksm(int x,int w)
{
    int s=1;
    while(w)
    {
        if(w&1)
            s = s*x%M;
        x = x*x%M;
        w >>= 1;
    }
    return s;
}

int find_(int x)
{
    return dsu[x]==x ? x : dsu[x] = find_(dsu[x]);
}

void union_(int x,int y)
{
    x = find_(x), y = find_(y);
    if(x!=y)
        dsu[y] = dsu[x];
    else if(x==y)
        c[dsu[x]] = 1;
}

void addedge(int x,int y)
{
    nex[++ee] = h[x], to[ee] = y, h[x] = ee;
}

void getroot(int x,int pre)
{
    if(root>0) return;
    int cnt=0;
    for(int i=h[x];i;i=nex[i])
        if(to[i]!=pre)
            cnt++, getroot(to[i],x);
    if(cnt==0)
        root = x;
}

void dfs(int x)
{
    b[++m] = a[x], vis[x] = 1;
    for(int i=h[x];i;i=nex[i])
        if(!vis[to[i]])
            dfs(to[i]);
}

void update(int o)
{
    t[o].v = (t[lson].v+t[rson].v+M)%M;
}

void build(int o,int l,int r)
{
    t[o].v = 0, t[o].laz1 = 1, t[o].laz2 = 0;
    if(l==r)
    {
        t[o].v = 0;
        return;
    }
    build(lson,l,mid), build(rson,mid+1,r);
    update(o);
}

void pushdown(int o,int l,int r)
{
    if(t[o].laz1!=1)
    {
        t[lson].v = t[lson].v*t[o].laz1%M;
        t[lson].laz1 = t[lson].laz1*t[o].laz1%M;
        t[lson].laz2 = t[lson].laz2*t[o].laz1%M;
        t[rson].v = t[rson].v*t[o].laz1%M;
        t[rson].laz1 = t[rson].laz1*t[o].laz1%M;
        t[rson].laz2 = t[rson].laz2*t[o].laz1%M;
    }
    if(t[o].laz2!=0)
    {
        t[lson].v = (t[lson].v+t[o].laz2*(mid-l+1))%M;
        t[lson].laz2 = (t[lson].laz2+t[o].laz2)%M;
        t[rson].v = (t[rson].v+t[o].laz2*(r-mid))%M;
        t[rson].laz2 = (t[rson].laz2+t[o].laz2)%M;
    }
    t[o].laz1 = 1, t[o].laz2 = 0;
}

void modify1(int o,int l,int r,int ql,int qr,int v)
{
    if(ql>qr) return;
    if(l>=ql && r<=qr)
    {
        t[o].v = t[o].v*v%M;
        t[o].laz1 = t[o].laz1*v%M, t[o].laz2 = t[o].laz2*v%M;
        return;
    }
    pushdown(o,l,r);
    if(ql<=mid)
        modify1(lson,l,mid,ql,qr,v);
    if(qr>mid)
        modify1(rson,mid+1,r,ql,qr,v);
    update(o);
}

void modify2(int o,int l,int r,int ql,int qr,int v)
{
    if(ql>qr) return;
    if(l>=ql && r<=qr)
    {
        t[o].v = (t[o].v+v*(r-l+1))%M;
        t[o].laz2 = (t[o].laz2+v)%M;
        return;
    }
    pushdown(o,l,r);
    if(ql<=mid)
        modify2(lson,l,mid,ql,qr,v);
    if(qr>mid)
        modify2(rson,mid+1,r,ql,qr,v);
    update(o);
}

int query(int o,int l,int r,int x)
{
    if(l==x && r==x)
        return t[o].v;
    pushdown(o,l,r);
    if(x<=mid)
        return query(lson,l,mid,x);
    else
        return query(rson,mid+1,r,x);
}

int calc1(int n)
{
    m = 0;
    for(int i=1;i<=n;i++)
        m = max(m,b[i]);
    modify1(1,1,m,1,m,0);
    modify2(1,1,m,1,b[1],1);
    for(int i=2;i<=n;i++)
    {
        int sum=t[1].v;
        modify1(1,1,m,1,b[i],-1), modify1(1,1,m,b[i]+1,m,0);
        modify2(1,1,m,1,b[i],sum);
    }
    return t[1].v%M;
}

int calc2(int n)
{
    m = 0;
    int p=0,len=0;
    for(int i=1;i<=n;i++)
    {
        m = max(m,b[i]);
        if(p==0 || b[i]<b[p])
            p = i;
        d[i] = b[i];
    }
    for(int i=p;i<=n;i++)
        b[++len] = d[i];
    for(int i=1;i<p;i++)
        b[++len] = d[i];
    modify1(1,1,m,1,m,0);
    modify2(1,1,m,1,1,1);
    for(int i=2;i<=n;i++)
    {
        int sum=t[1].v;
        modify1(1,1,m,1,b[i],-1), modify1(1,1,m,b[i]+1,m,0);
        modify2(1,1,m,1,b[i],sum);
    }
    return (t[1].v-query(1,1,m,1)+M)%M*b[1]%M;
}

signed main()
{
    freopen("rainbow.in","r",stdin);
    freopen("rainbow.out","w",stdout);
    scanf("%lld%lld",&n,&m);
    for(int i=1;i<=n;scanf("%lld",a+i),i++);
    for(int i=1;i<=n;i++) dsu[i] = i;
    for(int i=1,x,y;i<=m&&scanf("%lld%lld",&x,&y);i++)
    {
        if(x==y)
            return puts("0");
        addedge(x,y), addedge(y,x), union_(x,y);
    }
    for(int i=1;i<=n;i++)
        if(find_(i)==i)
        {
            m = 0;
            if(c[i]==0)
            {
                root = 0;
                getroot(i,0), dfs(root);
                ans = ans*calc1(m)%M;
            }
            else if(c[i]==1)
            {
                dfs(i);
                ans = ans*calc2(m)%M;
            }
        }
    cout<<ans%M<<endl;

    return 0;
}

祝大家 AC 愉快!