[NOIP1997 普及组] 棋盘问题(加强版) 简单记录

· · 题解

原题的 $n,m$ 被加强到 $10^{12}$（答案对 $10^9+7$ 取模）后出现在了模拟赛里。我推出了一个比较麻烦的做法，但是能过，在此记录一下。

正方形：以格子 $(i,j)$ 为左上角的正方形有 $\min(n-i+1,m-j+1)$ 个，所以正方形总数为 $\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{m}\min(n-i+1,m-j+1)$。
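长方形（包括正方形）：以格子 $(i,j)$ 为左上角的长方形有 $(n-i+1)(m-j+1)$ 个，所以总数为 $\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{m}(n-i+1)(m-j+1) = \sum\limits_{i=1}^{n}\sum\limits_{j=1}^{m} (nm-nj+n-im+ij-i+m-j+1) = \sum\limits_{i=1}^{n}\sum\limits_{j=1}^{m}\big((nm+n+m)+(-nj-im)+(-i-j)+ij+1\big)$，把五组分别求和即可（结果等于 $\frac{n(n+1)}{2}\cdot\frac{m(m+1)}{2}$）。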

最后用长方形总数减去正方形数，得到不含正方形的长方形数。复杂度 $O(1)$（另外需要用快速幂在 $O(\log p)$ 内求出 $2$ 和 $6$ 的逆元）。

#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand macros and type aliases.
#define FOR(i,a,b) for(int i=a;i<=b;i++)
#define REP(i,a,b) for(int i=a;i>=b;i--)
#define pb push_back
#define mkpr make_pair
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
// Number of test cases; main() fixes it to 1.
int T;
// From here on the token `ll` expands to __int128 (overriding the typedef above),
// so products of two or three values below `mod` fit without overflow.
#define ll __int128
const ll mod=1e9+7;
// Grid dimensions; solve() reduces them modulo `mod` right after reading them.
ll n,m;
// Unreduced copies of n and m, used to decide which side is shorter.
ll oldn,oldm;
// Intended to hold min(n, m) modulo `mod`; no function in this file reads it.
ll minnm;
// Binary exponentiation: returns a^b modulo `mod`, squaring the base once per bit of b.
ll qpow(ll a,ll b){
    ll result=1;
    for(ll base=a;b!=0;b>>=1){
        if((b&1)!=0)
            result=result*base%mod;
        base=base*base%mod;
    }
    return result;
}
// Modular inverse by Fermat's little theorem: for the prime `mod` and a not divisible
// by it, a^(mod-2) is the inverse of a.
ll inv(ll a){
    const ll fermat_exponent=mod-2;
    return qpow(a,fermat_exponent);
}
// Inverses of 2 and 6 modulo `mod`; solve() computes them before any formula uses them.
ll inv2,inv6;
// Counts all rectangles (squares included) in an n x m grid, modulo `mod`.
//
// Sums the expansion of (n-i+1)(m-j+1) over 1<=i<=n, 1<=j<=m group by group:
//   bra1 = sum (nm+n+m) = (nm+n+m) * nm
//   bra2 = sum (-nj-im) = -(n * n*m(m+1)/2 + m * m*n(n+1)/2)
//   bra3 = sum (-i-j)   = -(m*n(n+1)/2 + n*m(m+1)/2)
//   bra4 = sum (ij)     = n(n+1)/2 * m(m+1)/2
//   bra5 = sum 1        = nm
// (the five groups add up to n(n+1)/2 * m(m+1)/2).
// Reads the globals n and m, which solve() has already reduced modulo `mod`, and
// inv2, which solve() sets before calling this.
// Every term is kept in [0, mod), and the result is a residue in [0, mod).
ll solvechang(){
    ll bra1=(n*m+n+m)%mod*(n*m%mod)%mod;
    // n(n+1) and m(m+1) are always even, so the division by 2 is exact.
    ll bra3_i=((1ll+n)*(n)/2ll)%mod*m%mod;   // sum over all (i,j) of i
    ll bra3_j=((1ll+m)*m/2ll)%mod*n%mod;     // sum over all (i,j) of j
    // Negations are written as (mod - x) so the terms stay non-negative.
    ll bra3=(mod-bra3_i+mod-bra3_j)%mod;
    ll bra2=(mod-(bra3_i*m%mod+bra3_j*n%mod)%mod)%mod;
    ll sum_1n=((1ll+n)*n%mod*inv2)%mod;      // 1 + 2 + ... + n
    ll bra4=(sum_1n+sum_1n*m)%mod*m%mod*inv2%mod;
    ll bra5=n*m%mod;
    return (bra1+bra2+bra3+bra4+bra5)%mod;
}
// Sum of squares 1^2 + 2^2 + ... + n^2 modulo `mod`, from n(n+1)(2n+1)/6 with the
// division done by multiplying with inv6.
ll sum_11_nn(ll n){
    ll acc=n*(n+1)%mod;
    acc=acc*(2ll*n+1)%mod;
    acc=acc*inv6%mod;
    return acc;
}
// Triangular number 1 + 2 + ... + n modulo `mod`, via n(n+1)/2 with inv2 doing the halving.
ll sum(ll n){
    const ll r=n%mod;
    return (r+1)*r*inv2%mod;
}
// Counts the squares in an n x m grid, modulo `mod`.
//
// The number of squares whose top-left cell is (i, j) is min(n-i+1, m-j+1), so the
// answer is sum_{a=1..n} sum_{b=1..m} min(a, b). That double sum is symmetric in
// (n, m). Let s be the shorter side and L the longer one. Row a (1 <= a <= s) then
// contributes
//   (1 + 2 + ... + a) + a*(L-a) = a*(L-a+1) + (1 + 2 + ... + (a-1)),
// which splits the answer into two series
// (sum(x) = 1+...+x, sum_11_nn(x) = 1^2+...+x^2):
//   series 1: sum_{a=1..s} a*(L-a+1) = L*sum(s) - (sum_11_nn(s) - sum(s))
//   series 2: sum_{a=1..s} (1+...+(a-1)). Each k in [1, s-1] appears (s-k) times,
//             so this is sum_{k=1..s-1} k*(s-k) = s*sum(s-1) - sum_11_nn(s-1).
// The shorter side is chosen from the unreduced oldn/oldm, because solve() has
// already reduced the globals n and m modulo `mod`, and those reduced values no
// longer compare like the real sizes. Both series are polynomials whose
// denominators (2 and 6) are invertible, so evaluating them on the reduced sides
// is exact modulo `mod`.
// Returns the count as a residue in [0, mod).
ll solverect(){
    const bool n_is_shorter=oldn<oldm;
    const ll shorter=n_is_shorter?n:m;
    const ll longer=n_is_shorter?m:n;
    const ll series1=(longer*sum(shorter)%mod-(sum_11_nn(shorter)-sum(shorter)))%mod;
    const ll series2=(shorter*sum(shorter-1)%mod-sum_11_nn(shorter-1))%mod;
    // % may yield a negative remainder for negative operands; shift into [0, mod).
    return ((series1+series2)%mod+mod)%mod;
}
// Reads one test case "n m" from stdin and prints "squares rectangles", where the
// second number counts the rectangles that are not squares. Both numbers are
// reduced modulo `mod`. If the two integers cannot be read, nothing is printed.
void solve(){
    inv2=inv(2);
    inv6=inv(6);
    long long _n=0,_m=0;
    // Without this check a failed read would leave _n/_m unset and the answer garbage.
    if(scanf("%lld%lld",&_n,&_m)!=2)return;
    n=_n,m=_m;
    // Keep the exact sizes: deciding which side is shorter needs the real values,
    // and they stop comparing correctly once n and m are reduced below.
    oldn=n,oldm=m;
    n%=mod;
    m%=mod;
    ll changans=solvechang(),rectans=solverect();
    // solvechang() counts every rectangle, squares included; remove the squares.
    changans-=rectans;
    changans=(changans%mod+mod)%mod;
    rectans=(rectans%mod+mod)%mod;
    printf("%lld %lld\n",(long long)rectans,(long long)changans);
}
int main()
{
    // Exactly one test case per run; T counts down the cases still to process.
    for(T=1;T--;)
        solve();
    return 0;
}