[ABC259Ex] Yet Another Path Counting

· · 个人记录

\text{Links}

[ABC259Ex] Yet Another Path Counting

cnblogs

题外话

题意

给一个 n\times n 的网格图,每个格子上有一个颜色。

每一步只能往右或者往下走,问有多少条路径的起点和终点的颜色相同,对 998244353 取模。

------------ ### 题解 不同颜色的统计互不干扰,所以按颜色分开来统计。 考虑有用的信息只有起点和终点的颜色,所以枚举点对,组合数计算贡献,即 $y2-y1+x2-x1\choose x2-x1$。复杂度为 $O(siz^2)$,其中 $siz$ 为这种颜色的点数。 发现如果同种颜色的点数过大的话这个做法会 G。并且很难维护合并组合数的计算来降低复杂度。 但是点数有一个限制,即所有颜色的点数加起来为 $n^2$,于是可以考虑根号分治了! 设置阈值 $T$,当 $siz\le T$ 时,直接用上面的暴力做法,此部分总时间复杂度为 $O(\frac{n^2}{T}\times T^2)$,即 $O(n^2T)$。 当 $siz\gt T$ 时,这样的颜色最多只有 $\frac{n}{T}$ 种,那么对于每个颜色再搞个暴力做法。 考虑,这个暴力做法时间复杂度的正确性应该是不依赖于 $siz$ 的,不然我们根分有什么用呢?全部用这个做法不就好了吗。 所以考虑 $O(n^2)$ 的 $dp$,钦定我们当前 $solve$ 的颜色为 $col$。设 $dp_{i,j}$ 表示从颜色为 $col$ 的格子走到位置 $(i,j)$ 的方案数。 转移很简单:$dp_{i,j}=dp_{i,j-1}+dp_{i-1,j}+[a_{i,j}=col]$。 于是每个位置 $(i,j)$ 对 $ans$ 的贡献应该是 $[a_{i,j}=col]\times f_{i,j}$。此部分总时间复杂度为 $O(\frac{n^2}{T}\times n^2)$,即 $O(\frac{n^4}{T})$。 然后就做完了。取 $T=n$ 的时候达到平衡,总时间复杂度为 $O(n^3)$。 代码非常简单。(由于不怎么习惯用大量的 $pair$,所以这篇码风可能比较诡异) ------------- ### $\text{Code}
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define il inline
#define re register
const int N=405,T=400,mod=998244353;
int n,a[N][N],ans,fac[N<<1],inv[N<<1],invfac[N<<1],f[N][N];
#define pii pair<int,int>
#define mp make_pair
vector<pii >v[N*N];
il int read(){
    re int x=0,f=1;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return x*f;
}
il void Add(int &x,int y){
    x=(x+y)%mod;
}
il int C(int n,int m){
    if(n<0||m<0||n<m)return 0;
    return fac[n]*invfac[m]%mod*invfac[n-m]%mod;
}
il bool cmp(pii x,pii y){
    return y.first>=x.first&&y.second>=x.second;
}
il int disx(pii x,pii y){
    return y.first-x.first;
}
il int disy(pii x,pii y){
    return y.second-x.second;
}
#define nowi v[col][i]
#define nowj v[col][j]
il void solve1(int col){
    int siz=(int)v[col].size();
    for(re int i=0;i<siz;i++)
    for(re int j=i;j<siz;j++)
        if(cmp(nowi,nowj))Add(ans,C(disx(nowi,nowj)+disy(nowi,nowj),disx(nowi,nowj)));
}
il void solve2(int col){
    for(re int i=1;i<=n;i++)
        for(re int j=1;j<=n;j++){
            f[i][j]=(f[i-1][j]+f[i][j-1]+(a[i][j]==col))%mod;
            if(a[i][j]==col)Add(ans,f[i][j]);
        }
}
il void GetInv(){
    inv[1]=fac[1]=invfac[1]=fac[0]=invfac[0]=1;
    for(re int i=2;i<=(n<<1);i++){
        inv[i]=inv[mod%i]*(mod-mod/i)%mod;
        fac[i]=fac[i-1]*i%mod;
        invfac[i]=invfac[i-1]*inv[i]%mod;
    }
}
signed main(){
    n=read();
    GetInv();
    for(re int i=1;i<=n;i++)
        for(re int j=1;j<=n;j++)
            a[i][j]=read(),v[a[i][j]].push_back(mp(i,j));
    for(re int col=1;col<=n*n;col++){
        if(v[col].empty())continue;
        if((int)v[col].size()<=T)solve1(col);
        else solve2(col);
    }
    cout<<ans;
    return 0;
}