题解 P5175 【数列】

· · 题解

数据范围 $n\leq 10^{18}$，逐项递推显然会超时。

这摆明要用矩阵快速幂……

$ans=\sum_{i=1}^{n}a_i^2$，而 $a_n=x\cdot a_{n-1}+y\cdot a_{n-2}\Rightarrow a_n^2=x^2a_{n-1}^2+2xy\cdot a_{n-1}a_{n-2}+y^2a_{n-2}^2$

所以我们可以知道

$\color{red}{Once\;we\;know\;the\;values\;of\;a_{n-1}^2\;,\;a_{n-2}^2\;and\;a_{n-1}\cdot a_{n-2},}\;\color{green}{we\;can\;get\;the\;value\;of\;a_{n}^2}$

所以我们自然而然地想到构造矩阵

设 $S_n=ans(n)=\sum_{i=1}^{n}a_i^2$，则 $S_n=S_{n-1}+a_n^2$（这就是下面矩阵的第一行）；

$$\begin{bmatrix}S_n\\a_{n+1}^2\\a_{n+1}\cdot a_{n}\\a_{n}^2\end{bmatrix} = \begin{bmatrix} 1&\;1&\;0&\;0 \\ 0 &\;x^2&\;2\cdot xy&\;y^2 \\ 0&\;x&\;y&\;0 \\ 0&\;1&\;0&\;0 \end{bmatrix} \times \begin{bmatrix}S_{n-1}\\a_{n}^2\\a_{n}\cdot a_{n-1}\\a_{n-1}^2\end{bmatrix}$$

然后我们最初的矩阵是

$$\begin{bmatrix}S_{1}\\a_{2}^2\\a_{2}\cdot a_{1}\\a_{1}^2\end{bmatrix}$$

所以答案就是用构造矩阵的 $n-1$ 次幂左乘初始矩阵（列向量），结果的第一项即为 $S_n$；其中 $n-1$ 次幂用矩阵快速幂计算。

#include<cstdio>
// Shorthand for the 64-bit integer type used throughout. Kept as a macro
// (not a typedef) so every existing use of `ll` keeps its exact meaning.
#define ll long long
// Prime modulus 10^9 + 7, written as an exact integer literal.
const ll mod = 1000000007ll;

/// Small fixed-capacity matrix. Cells are addressed 1-based, so only
/// va[1..4][1..4] is used by this code; row/column 0 is left alone.
struct matrix {
    ll va[5][5];     // 1-based storage
    int line, cross; // rows / columns currently in use (set by the caller)

    /// Zero the 4x4 working area va[1..4][1..4].
    /// Does not modify `line` or `cross`.
    void Mem() {
        for (int cell = 0; cell < 16; ++cell) {
            va[cell / 4 + 1][cell % 4 + 1] = 0;
        }
    }
};

/// Modular matrix product a * b (1-based indexing).
/// The result has a.line rows and b.cross columns. The inner dimension is
/// taken from a.cross; b.line is assumed to match it (not checked).
/// Each product a*b must fit in a long long, so entries are expected to be
/// already reduced modulo `mod`.
matrix operator *(const matrix &a, const matrix &b) {
    matrix prod;
    prod.Mem();
    prod.line = a.line;
    prod.cross = b.cross;
    for (int row = 1; row <= a.line; ++row) {
        for (int col = 1; col <= b.cross; ++col) {
            ll acc = prod.va[row][col];
            // Reduce after every term so the running sum never grows.
            for (int mid = 1; mid <= a.cross; ++mid) {
                acc = (acc + a.va[row][mid] * b.va[mid][col] % mod) % mod;
            }
            prod.va[row][col] = acc;
        }
    }
    return prod;
}

int Starseven(void) {
    int t;
    read(t);
    matrix ans, txt;
    while(t--) {
        ll n, a1, a2, x, y;
        read(n);
        read(a1);
        read(a2);
        read(x);
        read(y);
        if(n == 1) {
            ll ott = a1 * a1 % mod; 
            write(ott);
            puts("");
            continue;
        }
        else if(n == 2) {
            ll ott = a1 * a1 % mod + a2 * a2 % mod;
            ott %= mod;
            write(ott);
            puts("");
            continue;
        }
        ans.Mem();
        txt.Mem();

        ans.line = 4;
        ans.cross = 1;
        ans.va[1][1] = a1 * a1 % mod;
        ans.va[2][1] = a2 * a2 % mod;
        ans.va[3][1] = a2 * a1 % mod;
        ans.va[4][1] = a1 * a1 % mod;

        txt.line = 4;
        txt.cross = 4;
        txt.va[1][1] = 1ll;
        txt.va[1][2] = 1ll;
        txt.va[2][2] = x * x % mod;
        txt.va[2][3] = 2ll * x % mod * y % mod;
        txt.va[2][4] = y * y % mod;
        txt.va[3][2] = x;
        txt.va[3][3] = y;
        txt.va[4][2] = 1ll;

        n -= 1ll;
        while(n) {
            if(n & 1ll) ans = txt * ans;
            n >>= 1ll;
            txt = txt * txt;
        }
        write(ans.va[1][1]);
        puts("");
    }
    return 0;   
}