题解

· · 个人记录

题目

题目(强化版)

Subtask1  15分

观察到测试数据一组,n\in [1,10^7]O(tn) 即可通过,记忆化搜索/线性 dp 即可通过

#include <bits/stdc++.h>
#define maxn 10000005
#define mod 998244353
using namespace std;
int t,n,F[maxn],ans;
int main(){
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    cin>>t;
    while(t--){
        cin>>n;
        F[1]=F[2]=1;ans=0;
        for(int i=2;i<=n;i++)F[i]=(F[i-1]+F[i-2])%mod;
        for(int i=1;i<=n;i++)(ans+=F[i])%=mod;
        cout<<ans<<'\n';
    }
}

Subtask6 100分

观察 Fibonacci 数列的递推式 F(n)=F(n-1)+F(n-2) ,可以发现这是一个线性递推,显然有转移矩阵 \binom{0,1}{1,1} ,进而可以利用矩阵快速幂在O(logn) 时间复杂度算出第 nFibonacci 数列的数值。

然而题目要求求出前 nFibonacci 数列之和,仅仅能够求出第 n 项是不够的,因此我们想到可以多维护一个矩阵的信息,不妨设 f(n)=\sum_{i=1}^{n} F(i) ,则 \begin{pmatrix} F(n)\\ F(n-1)\\ f(n) \end{pmatrix}=\begin{pmatrix} F(n-1)+F(n-2)\\ F(n-1)\\ f(n-1)+F(n-1)+F(n-2) \end{pmatrix}=\begin{pmatrix} 1 & 1 & 0\\ 0 & 1 & 0\\ 1 & 1 & 1 \end{pmatrix}\begin{pmatrix} F(n)\\ F(n-1)\\ f(n) \end{pmatrix} ,有转移矩阵 \begin{pmatrix} 1 & 1 & 0\\ 0 & 1 & 0\\ 1 & 1 & 1 \end{pmatrix} 维护 \begin{pmatrix} F(n)\\ F(n-1)\\ f(n) \end{pmatrix} 的转移。

进而使用矩阵快速幂,即可 O(logn) 求得 f(n)

通过,最慢一个点 899ms 。(记录)

#include<bits/stdc++.h>
#define ll long long
#define mod 998244353
using namespace std;
ll base[3][3]={{0,1,0},{1,1,0},{1,1,1}},a[2][3][3],b[2][3][3],n;
int t;
bool opa,opb;
inline void c(){
    for(int i=0;i<3;i++)for(int j=0;j<3;j++)b[opb^1][i][j]=(b[opb][0][j]*a[opa][i][0]+b[opb][1][j]*a[opa][i][1]+b[opb][2][j]*a[opa][i][2])%mod;
    opb^=1;
}
inline void c_(){
    for(int i=0;i<3;i++)for(int j=0;j<3;j++)a[opa^1][i][j]=(a[opa][0][j]*a[opa][i][0]+a[opa][1][j]*a[opa][i][1]+a[opa][2][j]*a[opa][i][2])%mod;
    opa^=1;
}
inline void ksm(ll x){
    for(int i=0;i<3;i++)for(int j=0;j<3;j++)b[opb][i][j]=0;
    b[opb][0][0]=b[opb][1][1]=b[opb][2][2]=1;
    for(int i=0;i<3;i++)for(int j=0;j<3;j++)a[opa][i][j]=base[i][j];
    while(x){
        if(x&1)c();
        c_();
        x>>=1;
    }
}
int main(){
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    cin>>t;
    while(t--){
        cin>>n;
        if(n==0){cout<<"0\n";continue;}
        ksm(n-1);
        cout<<(b[opb][2][1]+b[opb][2][2])%mod<<'\n';
    } 
}

 优化1

由于每次的计算过程中,矩阵 a 都被进行了许多次重复计算,可以想到预处理矩阵 a ,也就是 2 的幂的转移矩阵,每次计算直接使用。

常数减小超过一半,最慢一个点 405ms ,这是因为实际运行过程中并不是 2 的每一个幂的转移矩阵都要使用,这个优化大大减小了运算次数。(记录)

#include<bits/stdc++.h>
#define ll long long
#define mod 998244353
using namespace std;
ll base[3][3]={{0,1,0},{1,1,0},{1,1,1}},a[65][3][3],b[2][3][3],n;
int t;
bool op;
inline void c(int w){
    for(int i=0;i<3;i++)for(int j=0;j<3;j++)b[op^1][i][j]=(b[op][0][j]*a[w][i][0]+b[op][1][j]*a[w][i][1]+b[op][2][j]*a[w][i][2])%mod;
    op^=1;
}
inline void ksm(ll x){
    for(int i=0;i<3;i++)for(int j=0;j<3;j++)b[op][i][j]=0;
    b[op][0][0]=b[op][1][1]=b[op][2][2]=1;
    int w=0; 
    while(x){
        if(x&1)c(w);
        x>>=1;
        w++;
    }
}
int main(){
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    for(int i=0;i<3;i++)for(int j=0;j<3;j++)a[0][i][j]=base[i][j];
    for(int i=1;i<=62;i++)for(int j=0;j<3;j++)for(int k=0;k<3;k++)a[i][j][k]=(a[i-1][0][k]*a[i-1][j][0]+a[i-1][1][k]*a[i-1][j][1]+a[i-1][2][k]*a[i-1][j][2])%mod; 
    cin>>t;
    while(t--){
        cin>>n;
        if(n==0){cout<<"0\n";continue;}
        ksm(n-1);
        cout<<(b[op][2][1]+b[op][2][2])%mod<<'\n';
    } 
}

优化2

让我们重新思考这个问题。我们可以通过使用矩阵快速幂 O(logn) 求得 Fibonacci 数列第 n 项,但是为了同时维护求和,需要将矩阵从 2 维增加到了 3 维,这一过程增加了 (3^2-2^2)/2^2=125\% 的计算量,这能否被避免?

注意到:

2*f(n) =F(1)+F(n)+\sum_{i=1}^{n-1} (F(i)+F(i+1)) =F(1)+F(n)+\sum_{i=3}^{n+1} F(i) =F(1)+F(n)+f(n+1)-F(2)-F(1) =F(n)+f(n+1)-F(2) =F(n)+f(n)+F(n+1)-F(2) =f(n)+F(n)+F(n+1)-1

所以有 f(n)=F(n)+F(n+1)-1 ,因此只需要求出 F(n)F(n+1) 即可计算前 n 项之和,这只需要 2 维矩阵维护,即: \begin{pmatrix} F(n)\\ F(n-1)\end{pmatrix}=\begin{pmatrix} F(n-1)+F(n-2)\\ F(n-1)\end{pmatrix}=\begin{pmatrix} 1 & 1\\0 & 1 \end{pmatrix}\begin{pmatrix} F(n-1)\\ F(n-2)\end{pmatrix}

又快了近一半,最慢一个点 245ms ,至于为什么没有变为原来的 4/9 ,有一部分原因用在了输入输出上。(记录)

#include<bits/stdc++.h>
#define ll long long
#define mod 998244353
using namespace std;
ll a[65][2][2],b[2][2][2],n;
int t;
bool op;
inline void c(int w){
    b[op^1][0][0]=(a[w][0][0]*b[op][0][0]+a[w][0][1]*b[op][1][0])%mod;
    b[op^1][0][1]=(a[w][0][0]*b[op][0][1]+a[w][0][1]*b[op][1][1])%mod;
    b[op^1][1][0]=(a[w][1][0]*b[op][0][0]+a[w][1][1]*b[op][1][0])%mod;
    b[op^1][1][1]=(a[w][1][0]*b[op][0][1]+a[w][1][1]*b[op][1][1])%mod;
    op^=1;
}
inline void ksm(ll x){
    b[op][0][1]=b[op][1][0]=0,b[op][0][0]=b[op][1][1]=1;
    int w=0; 
    while(x){
        if(x&1)c(w);
        x>>=1;
        w++;
    }
}
int main(){
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    a[0][0][0]=0,a[0][0][1]=1,a[0][1][0]=1,a[0][1][1]=1;
    for(int i=1;i<=62;i++){
        a[i][0][0]=(a[i-1][0][0]*a[i-1][0][0]+a[i-1][0][1]*a[i-1][1][0])%mod;
        a[i][0][1]=(a[i-1][0][0]*a[i-1][0][1]+a[i-1][0][1]*a[i-1][1][1])%mod;
        a[i][1][0]=(a[i-1][1][0]*a[i-1][0][0]+a[i-1][1][1]*a[i-1][1][0])%mod;
        a[i][1][1]=(a[i-1][1][0]*a[i-1][0][1]+a[i-1][1][1]*a[i-1][1][1])%mod;
    } 
    cin>>t;
    while(t--){
        cin>>n;
        if(n==0){cout<<"0\n";continue;}
        ksm(n-1);
        cout<<(b[op][0][0]+b[op][0][1]+b[op][1][0]+b[op][1][1]-1)%mod<<'\n';
    } 
}

优化3

注意到维护转移 n-1 次的矩阵的最终目的是用它左乘 \binom1 1 来求得 \begin{pmatrix}F(n)\\ F(n+1)\end{pmatrix} ,然而最终的 n-1 次的矩阵本身是没有用的,所以可以直接用 12 列矩阵初始赋值 \binom1 1 ,每次用对应 2 的幂次转移矩阵左乘维护答案即可。

最慢一个点 209ms (记录)

#include<bits/stdc++.h>
#define ll long long
#define mod 998244353
using namespace std;
ll a[65][2][2],b[2][2],n;
int t;
bool op;
inline void c(int w){
    b[op^1][0]=(a[w][0][0]*b[op][0]+a[w][0][1]*b[op][1])%mod;
    b[op^1][1]=(a[w][1][0]*b[op][0]+a[w][1][1]*b[op][1])%mod;
    op^=1;
}
inline void ksm(ll x){
    b[op][0]=b[op][1]=1;
    int w=0; 
    while(x){
        if(x&1)c(w);
        x>>=1;
        w++;
    }
}
int main(){
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    a[0][0][0]=0,a[0][0][1]=1,a[0][1][0]=1,a[0][1][1]=1;
    for(int i=1;i<=62;i++){
        a[i][0][0]=(a[i-1][0][0]*a[i-1][0][0]+a[i-1][0][1]*a[i-1][1][0])%mod;
        a[i][0][1]=(a[i-1][0][0]*a[i-1][0][1]+a[i-1][0][1]*a[i-1][1][1])%mod;
        a[i][1][0]=(a[i-1][1][0]*a[i-1][0][0]+a[i-1][1][1]*a[i-1][1][0])%mod;
        a[i][1][1]=(a[i-1][1][0]*a[i-1][0][1]+a[i-1][1][1]*a[i-1][1][1])%mod;
    } 
    cin>>t;
    while(t--){
        cin>>n;
        if(n==0){cout<<"0\n";continue;}
        ksm(n-1);
        cout<<(b[op][0]+b[op][1]-1)%mod<<'\n';
    } 
}

优化4

考虑优化快速幂。目前实现的快速幂实际上是每次对2取模,导致每次运算为 \log_2{10^{18}}≈60 次,如果增大模数,每次运算次数就会降低。因此考虑将模数增大为1024,此时每次运算为 \log_{1024}{10^{18}}≈6 次。而对应的运算矩阵预处理出来即可。

最慢一个点 97ms (记录)

#include<bits/stdc++.h>
#define ll long long
#define mod 998244353
using namespace std;
ll a[7][1024][2][2],b[2][2],n;
int t;
bool op;
inline void c(int w,int w_){
    b[op^1][0]=(a[w][w_][0][0]*b[op][0]+a[w][w_][0][1]*b[op][1])%mod;
    b[op^1][1]=(a[w][w_][1][0]*b[op][0]+a[w][w_][1][1]*b[op][1])%mod;
    op^=1;
}
inline void ksm(ll x){
    b[op][0]=b[op][1]=1;
    int w=0; 
    while(x){
        c(w,x%1024);
        x>>=10;
        w++;
    }
}
int main(){
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    a[0][1][0][0]=0,a[0][1][0][1]=a[0][1][1][0]=a[0][1][1][1]=1;
    for(int i=0;i<=6;i++){
        a[i][0][0][0]=a[i][0][1][1]=1,a[i][0][0][1]=a[i][0][1][0]=0;
        for(int j=2;j<1024;j++){
            a[i][j][0][0]=(a[i][1][0][0]*a[i][j-1][0][0]+a[i][1][0][1]*a[i][j-1][1][0])%mod;
            a[i][j][0][1]=(a[i][1][0][0]*a[i][j-1][0][1]+a[i][1][0][1]*a[i][j-1][1][1])%mod;
            a[i][j][1][0]=(a[i][1][1][0]*a[i][j-1][0][0]+a[i][1][1][1]*a[i][j-1][1][0])%mod;
            a[i][j][1][1]=(a[i][1][1][0]*a[i][j-1][0][1]+a[i][1][1][1]*a[i][j-1][1][1])%mod;
        }
        if(i<6){
            a[i+1][1][0][0]=(a[i][512][0][0]*a[i][512][0][0]+a[i][512][0][1]*a[i][512][1][0])%mod;
            a[i+1][1][0][1]=(a[i][512][0][0]*a[i][512][0][1]+a[i][512][0][1]*a[i][512][1][1])%mod;
            a[i+1][1][1][0]=(a[i][512][1][0]*a[i][512][0][0]+a[i][512][1][1]*a[i][512][1][0])%mod;
            a[i+1][1][1][1]=(a[i][512][1][0]*a[i][512][0][1]+a[i][512][1][1]*a[i][512][1][1])%mod;
        }
    }
    cin>>t;
    while(t--){
        cin>>n;
        if(n==0){cout<<"0\n";continue;}
        ksm(n-1);
        cout<<(b[op][0]+b[op][1]-1)%mod<<'\n';
    } 
}

优化5

由于输入 5*10^5 ,每个 n18 位,快读应该可以优化很大的常数。

最慢一个点 64ms (记录)

#include<bits/stdc++.h>
#define getc (cs==ct&&(ct=(cs=cb)+fread(cb,1,1<<15,stdin),cs==ct)?0:*cs++)
#define ll long long
#define mod 998244353
using namespace std;
char cb[1<<15],*cs,*ct;
inline void read(auto &num){
    char ch;int f=1;
    while(!isdigit(ch=getc))if(ch=='-')f=-1;
    for(num=ch-'0';isdigit(ch=getc);num=num*10+ch-'0');
    num*=f;
}
ll a[7][1024][2][2],b[2][2],n;
int t;
bool op;
inline void c(int w,int w_){
    b[op^1][0]=(a[w][w_][0][0]*b[op][0]+a[w][w_][0][1]*b[op][1])%mod;
    b[op^1][1]=(a[w][w_][1][0]*b[op][0]+a[w][w_][1][1]*b[op][1])%mod;
    op^=1;
}
inline void ksm(ll x){
    b[op][0]=b[op][1]=1;
    int w=0; 
    while(x){
        c(w,x%1024);
        x>>=10;
        w++;
    }
}
int main(){
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    a[0][1][0][0]=0,a[0][1][0][1]=a[0][1][1][0]=a[0][1][1][1]=1;
    for(int i=0;i<=6;i++){
        a[i][0][0][0]=a[i][0][1][1]=1,a[i][0][0][1]=a[i][0][1][0]=0;
        for(int j=2;j<1024;j++){
            a[i][j][0][0]=(a[i][1][0][0]*a[i][j-1][0][0]+a[i][1][0][1]*a[i][j-1][1][0])%mod;
            a[i][j][0][1]=(a[i][1][0][0]*a[i][j-1][0][1]+a[i][1][0][1]*a[i][j-1][1][1])%mod;
            a[i][j][1][0]=(a[i][1][1][0]*a[i][j-1][0][0]+a[i][1][1][1]*a[i][j-1][1][0])%mod;
            a[i][j][1][1]=(a[i][1][1][0]*a[i][j-1][0][1]+a[i][1][1][1]*a[i][j-1][1][1])%mod;
        }
        if(i<6){
            a[i+1][1][0][0]=(a[i][512][0][0]*a[i][512][0][0]+a[i][512][0][1]*a[i][512][1][0])%mod;
            a[i+1][1][0][1]=(a[i][512][0][0]*a[i][512][0][1]+a[i][512][0][1]*a[i][512][1][1])%mod;
            a[i+1][1][1][0]=(a[i][512][1][0]*a[i][512][0][0]+a[i][512][1][1]*a[i][512][1][0])%mod;
            a[i+1][1][1][1]=(a[i][512][1][0]*a[i][512][0][1]+a[i][512][1][1]*a[i][512][1][1])%mod;
        }
    }
    read(t);
    while(t--){
        read(n);
        if(n==0){cout<<"0\n";continue;}
        ksm(n-1);
        cout<<(b[op][0]+b[op][1]-1)%mod<<'\n';
    } 
}