P3166

· · 个人记录

[CQOI2014]数三角形

数学题。

然后容斥,任选三点的方案 - 三点共线的方案。

那么任选三点方案为 C_{(n+1)(m+1)}^3

三点共线有一点点麻烦。

首先在同一行或同一列很好搞,方案为 (n+1)C_{m+1}^3+(m+1)C_{n+1}^3

在同一个斜线上并不好做。

我们可以随意枚举两点,然后这两点之间的线段上的点的个数就是这两个点可以给答案做得共线,而且是能做到不重不漏的。

但是枚举两点复杂度较高,我们干脆采用枚举两点横纵坐标之差来计数。

然后这个计数的规律可以瞎猜,或者严谨一点地瞎猜。

中间去掉两个端点的整点的个数为 \gcd(x,y)-1 ,其中 x,y 分别为横纵坐标之差。

上面不过是一个结论,现在来看如何证明。

我之前想不清楚这个结论主要是因为绕在 \gcd(x,y) 是长度的思维定式的怪圈了。

实际上,可以把 \gcd(x,y)=d 看作把横纵坐标差平分为 d 份,这样的话整格点的个数就是 \gcd(x,y)-1 了(除去线段两端点)。

然后又要考虑到这样的线段在平面内可以平移,所以说往左右平移有 (n-i+1) 种可能,往上下平移又有 (m-j+1) 种可能,那么根据乘法原理相乘即可。

然后又要考虑,我们刚才假设的直线都是斜率为正,还要把斜率为负再考虑一遍,乘上 2 即可。

那么最后答案就是:

ans=C_{(n+1)(m+1)}^3-[(n+1)C_{m+1}^3+(m+1)C_{n+1}^3]-2\sum_{i=1}^n\sum_{j=1}^m(n-i+1)(m-j+1)(\gcd(i,j)-1)

这样计算的话,时间复杂度为 O(nm\log (n+m)) ,已经可以通过本题。

然后就看到了更神奇的 O(n) 做法。

根据欧拉函数的计算公式以及等比数列的求和公式,经过推导之后可以得到如下结论(即欧拉反演):

\sum_{d|n}\varphi(d)=n

n 替换为 \gcd(i,j) 即是:

\sum_{d|\gcd(i,j)}\varphi(d)=\gcd(i,j)

那么我们把这个代入回原式:

ans=C_{(n+1)(m+1)}^3-[(n+1)C_{m+1}^3+(m+1)C_{n+1}^3]-2\sum_{i=1}^n\sum_{j=1}^m(n-i+1)(m-j+1)(\sum_{d|\gcd(i,j)}\varphi(d)-1)

因为 \varphi(1)=1 ,我们再把 1 剔除:

ans=C_{(n+1)(m+1)}^3-[(n+1)C_{m+1}^3+(m+1)C_{n+1}^3]-2\sum_{i=1}^n\sum_{j=1}^m(n-i+1)(m-j+1)(\sum_{d|\gcd(i,j)}^(d\not=1)\varphi(d))

然后我们转换一下思维。这里相当于先枚举 i,j ,再寻找 i,j 的约数,效率是较低的。但是我们先枚举除 1 以外的约数 d ,可以肯定 d\in[2,\min\{n,m\}]

然后我们可以通过枚举这个约数的倍数来确定 i,j ,相当于把顺序调换过来。

那么这个答案可以表示为:

ans=C_{(n+1)(m+1)}^3-[(n+1)C_{m+1}^3+(m+1)C_{n+1}^3]-2\sum_{d=2}^{\min\{n,m\}}\varphi(d)\sum_{i=1}^{\lfloor \frac{n}{d}\rfloor}\sum_{j=1}^{\lfloor\frac{m}{d}\rfloor}(n-id+1)(m-jd+1)

在欧拉筛的过程中顺便求出欧拉函数。

这个推导利用了下面两个结论:

  1. 若质数 p|n,p^2|n ,则 \varphi(n)=\varphi(\frac{n}{p})p

  2. 若质数 p|n,p^2\nmid n ,则 \varphi(n)=\varphi(\frac{n}{p})(p-1)

然后我们继续探讨上式,我们发现它可以这样变化:

ans=C_{(n+1)(m+1)}^3-[(n+1)C_{m+1}^3+(m+1)C_{n+1}^3]-2\sum_{d=2}^{\min\{n,m\}}\varphi(d)\sum_{i=1}^{\lfloor \frac{n}{d}\rfloor}(n-id+1)\sum_{j=1}^{\lfloor\frac{m}{d}\rfloor}(m-jd+1)

然后我们发现 \sum_{i=1}^{\lfloor \frac{n}{d}\rfloor}(n-id+1)\sum_{j=1}^{\lfloor\frac{m}{d}\rfloor}(m-jd+1) 是等差数列求和。

所以上式又可以变为:

ans=C_{(n+1)(m+1)}^3-[(n+1)C_{m+1}^3+(m+1)C_{n+1}^3]-\frac{1}{2}\sum_{d=2}^{\min\{n,m\}}\varphi(d)\sum_{i=1}^{\lfloor \frac{n}{d}\rfloor}(n-d+n\bmod d+2)\sum_{j=1}^{\lfloor\frac{m}{d}\rfloor}(m-d+m\bmod d+2)

然后就可以 O(n) 求了。

代码(O(n)):

#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll N=2e3;

ll n,m,ans,cnt;

ll prime[N+5],phi[N+5];

bool f[N+5];

void init() {
    f[1]=1;phi[1]=1;
    for(ll i=2;i<=n||i<=m;i++) {
        if(!f[i]) {prime[++cnt]=i;phi[i]=i-1;}
        for(ll j=1;j<=cnt&&(i*prime[j]<=n||i*prime[j]<=m);j++) {
            f[i*prime[j]]=1;
            if(i%prime[j]==0) {
                phi[i*prime[j]]=phi[i]*prime[j];
                break;
            }
            else phi[i*prime[j]]=phi[i]*(prime[j]-1);
        }
    }

}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    if(x<0) {x=-x;putchar('-');}
    ll y=10,len=1;
    while(y<=x) {y*=10;len++;}
    while(len--) {y/=10;putchar(x/y+48);x%=y;}
}

int main() {

    n=read();m=read();

    init();

    ans=(((n+1)*(m+1))*((n+1)*(m+1)-1)*((n+1)*(m+1)-2))/6;

    ans=ans-(n*(n+1)*(n-1))/6*(m+1);

    ans=ans-(m*(m+1)*(m-1))/6*(n+1);

    for(ll d=2;d<=n&&d<=m;d++) {
        ans-=phi[d]*(n-d+n%d+2)*(n/d)*(m-d+m%d+2)*(m/d)/2;
    }

    write(ans);

    return 0;
}

代码(O(nm)):

#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

ll ans,n,m;

ll gcd(ll a,ll b) {
    if(b==0) return a;
    return gcd(b,a%b);
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    if(x<0) {x=-x;putchar('-');}
    ll y=10,len=1;
    while(y<=x) {y*=10;len++;}
    while(len--) {y/=10;putchar(x/y+48);x%=y;}
}

int main() {

    n=read();m=read();

    ans=(((n+1)*(m+1))*((n+1)*(m+1)-1)*((n+1)*(m+1)-2))/6;

    ans=ans-(n*(n+1)*(n-1))/6*(m+1);

    ans=ans-(m*(m+1)*(m-1))/6*(n+1);

    for(ll i=1;i<=n;i++) {
        for(ll j=1;j<=m;j++) {
            ans=ans-2*(gcd(i,j)-1)*(n-i+1)*(m-j+1);
        }
    }

    write(ans);

    return 0;
}