[数学]秦九韶算法

Nickel_Angel

2019-06-17 22:14:47

Personal

这里是秦九韶$(sh\acute{a}o)$算法的学习笔记$qwq$ 这个算法主要是快速对一个一元$n$次多项式进行求值,一个一元$n$次多项式形如: $$f(x) = \sum_{i = 0}^{n}a_ix^i$$ 即: $$f(x) = a_nx^n + a_{n-1}x^{n-1}+\cdots+a_1x+a_0 $$ 如果朴素算法则需要求出$x^2,x^3,\cdots,x^n$的值,代码写起来就像这样: ~~(经询问和实验,这个算法的复杂度与下方秦九韶算法的复杂度相等……)~~ ```cpp int f(int x, int *a, int n) { int res = 0, tmp = 1; for (int i = 0; i <= n; ++i) { res += a[i] * tmp; tmp *= x; } return res; } ``` 可以考虑将其改写为如下形式: $$\begin{aligned}f(x)&=a_nx^n + a_{n-1}x^{n-1}+\cdots+a_1x+a_{0}\\&=(a_nx^{n-1}+a_{n-1}x^{n-2}+\cdots+a_2x+a_1)x+a_0\\&=((a_nx^{n-2}+a_{n-1}x^{n-3}+\cdots+a_3x+a_2)x+a_1)x+a_0\\&\ \ \vdots\\&=(\cdots((a_nx+a_{n-1})x+a_{n-2})x+\cdots+a_1)x+a_0\end{aligned}$$ 计算时,首先计算最内层括号内一次多项式的值,即: $$\begin{aligned}v_0&=a_n\\v_1&=a_nx+a_{n-1}\end{aligned}$$ 然后由内向外逐层计算一次多项式的值,即: $$\begin{aligned}v_2&=v_1x+a_{n-2}\\v_3&=v_2x+a_{n-3}\\&\ \ \vdots\\v_n&=v_{n-1}x+a_0\end{aligned}$$ 最终$v_n$即为所求,所以只要维护一个变量,不断迭代即可…… 这样计算只用$n$次加法,$n$次乘法就可以求出一个一元$n$次多项式的值…… Code: ```cpp int f(int x, int *a, int n) { int res = 0; for (int i = n; i >= 0; --i) res = res * x + a[i]; return res; } ``` 例题: [P5148 大循环](https://www.luogu.org/problemnew/show/P5148) 读题,发现是需要求: $$\sum_{a_1=1}^{n}\sum_{a_2=1}^{a_1 - 1}\sum_{a_3=1}^{a_2 - 1}\cdots\sum_{a_k=1}^{a_{k-1}-1} f(q)$$ 可以将$f(q)$提出并将后面的和式进行等价转换,得: $$f(q) \times \sum_{a_1=1}^{n}\sum_{a_2=1}^{n}\sum_{a_3=1}^{n}\cdots\sum_{a_k=1}^{n}[n \geq a_1 > a_2 > \cdots > a_k \geq 1]$$ (这里由于原式的第一个和式的上界为$n$,且从第二个和式起,后边的第$i(2 \leq i \leq k)$个和式下界均为$1$,上界均为$a_{i-1}-1$,即满足$\forall j\in[1,k],1 \leq a_j \leq n$且$a_i < a_{i-1} (2 \leq i \leq k)$故等价于$n \geq a_1 > a_2 > \cdots > a_k \geq 1$的约束条件) $f(q)$可利用秦九韶公式求得,考虑后面的式子的意义: 不难发现,$a_1,a_2,\cdots,a_k$的取值均为$[1,n]$,而我们需要使$a_1,a_2,\cdots,a_k$严格递减,所以这个和式等价于在$[1,n]$中从大到小依次取$k$个互不相同的数方案数,不难发现,根据组合数的定义,这个方案数就是$C(n,k)$,故答案: $$ans=f(q) \times C(n,k)$$ 直接套公式求值即可。 Code: ```cpp #include <cstdio> #include <cstring> #include <algorithm> #include <iostream> using namespace std; const int p = 1e9 + 7; int n, m, k, A[500010]; long long q; inline int power(int a, int b) { int res = 1; while (b) { if (b & 1) res = 1ll * res * a % p; a = 1ll * a * a % p; b >>= 1; } return res; } inline int C(int a, int b) { if ((b << 1) > n) b = a - b;//利用组合数的互反率:C(n,k)=C(n,n-k),进行常数优化 int fac = 1, inv_fac = 1; for (int i = a - b + 1; i <= a; ++i) fac = 1ll * fac * i % p; for (int i = 1; i <= b; ++i) inv_fac = 1ll * inv_fac * i % p; inv_fac = power(inv_fac, p - 2); return (int)(1ll * inv_fac * fac % p); } inline int f(int x, int *a) { int res = 0; for (int i = m; i >= 0; --i) res = (1ll * res * x + a[i]) % p; return res; } int main() { scanf("%d%d%d%lld", &n, &m, &k, &q); for (int i = 0; i <= m; ++i) scanf("%d", A + i); printf("%d", (int)(1ll * f(q % p, A) * C(n, k) % p)); return 0; } ```