拉格朗日插值法

· · 个人记录

$\ \ \ \ \ \ \ $已知在二维平面上,一多项式$f$的函数图像经过 $n+1$ 个点 $(x_i,y_i)$ ,既我们知道 $f(x_i)=y_i$,求$f(k)$的值。 - ### 初步分析 $\ \ \ \ \ \ \ $显然,经过 $n+1$ 个点的多项式函数一定是$n$次函数,那么也很显然的,我们可以列出一个 $n$ 元方程组来解决这个问题,形如: $$\begin{cases}y_1=a_1\cdot x_1^n+a_2\cdot x_1^{n-1}+a_3\cdot x_1^{n-2}+...+a_{n+1}\\y_2=a_1\cdot x_2^n+a_2\cdot x_2^{n-1}+a_3\cdot x_2^{n-2}+...+a_{n+1}\\\ \ \ \ \ \ \ \ \ \ \ \ \ ...\\y_n=a_1\cdot x_n^n+a_2\cdot x_n^{n-1}+a_3\cdot x_n^{n-2}+...+a_{n+1}\end{cases}$$ $\ \ \ \ \ \ \ $直观的,我们可以高斯消元来做,复杂度$O(n^3)$,复杂度多半已上天,那么我们如何快速处理呢? - ### 插值法 $\ \ \ \ \ \ \ $我们现在知道的是 $f(x_i)=y_i$,那么我们想怎么快速表达出这个多项式,我们设定一种函数 $S_i$: $$S_i(x)=[x=x_i]$$ $\ \ \ \ \ \ \ $它的意义是只有当 $x$ 为 $x_i$ 函数值才为 $1$,其他为 $0$,那么显然有: $$f=\sum_{i=1}^n y_i\cdot S_i$$ $\ \ \ \ \ \ \ $很显然的,$S_i$ 是个 $n$ 次多项式,而我们的答案也就是: $$f(k)=\sum_{i=1}^n y_i\cdot S_i(k)$$ $\ \ \ \ \ \ \ $遗憾地告诉你,光靠这个办法是不能还原$f$的函数图像的,只能得到近似图像,所以插值法是有一定误差的。 $\ \ \ \ \ \ \ $现在我们简化了问题,如何求$S_i$ 是个 $n$ 次多项式。 $\ \ \ \ \ \ \ $其实满足$S_i(x)$的函数取值合法挺简单的,主要问题是如何满足他是 $n$ 次多项式的事实,那么我们就先把它化作下面的形态: $$S_i(k)=\prod_{j=1,j\neq \rm something}^{n+1}a_jk+b_j$$ $\ \ \ \ \ \ \ $如何就卡住了,我们不知道怎么办,但是还记得吗,我们只能得到近似图像,所以我们只需要满足$S_i(x)$的图像大约在$[x=x_i]$就行了,拉格朗日给了一种解法: $$S_i(k)=\prod_{j=1,j\neq i}^{n+1}\frac{k-x_j}{x_i-x_j}$$ $\ \ \ \ \ \ \ $具体证明可以看这里:[【 VictoryCzt_拉格朗日插值法学习笔记】](https://blog.csdn.net/VictoryCzt/article/details/82933843) $\ \ \ \ \ \ \ $所以得到拉格朗日公式($n$个点),复杂度是$O(n^2)$的: $$f(k)=\sum_{i=1}^{n} y_i\prod_{j=1,j\neq i}^{n}\frac{k-x_j}{x_i-x_j}$$ - ## 代码实现: $\ \ \ \ \ \ \ $普通的拉格朗日差值法其实代码很简单了: ```cpp double Lagrange(double *x,double *y,double n,double k){ double top,bot,ret=0; for(double i=1;i<=n;i++){ top=1.0;bot=1.0; for(double j=1;j<=n;++j)if(i!=j) top*=k-x[j], bot*=x[i]-x[j]; ret+=y[i]*top/bot; } return ret; } ``` $\ \ \ \ \ \ \ $有些时候呢,我们的 $x_i$ 是连续的,所以说公式变形为: $$f(k)=\sum_{i=1}^{n} y_i\prod_{j=1,j\neq i}^{n}\frac{k-j}{i-j}$$ $\ \ \ \ \ \ \ $我们令: $$pre_i=\prod_{j=1}^{i}(k-j)=pre_{i-1}\times(k-j)$$ $$suf_i=\prod_{j=i}^{n}(k-j)=suf_{i+1}\times(k-j)$$ $\ \ \ \ \ \ \ $显然对于这两个函数是可以$O(n)$预处理的。 $\ \ \ \ \ \ \ $原公式可以化成: $$f(k)=\sum_{i=1}^{n} y_i \frac{pre_{i-1}\times suf_{i+1}}{(i-1)!\times(n-i)!}$$ $\ \ \ \ \ \ \ $注意分母是有符号需要判断的,$n-i$为奇数时,分母为负。 $\ \ \ \ \ \ \ $显然阶乘也可以预处理,于是乎复杂度变成了$O(n)$。 ```cpp long long pre[N],suf[N],fac[N]; long long Lagrange(long long *y,long long n,long long k){ long long top,bot,ret=0; pre[0]=suf[n+1]=fac[0]=1; for(int i=1;i<=n;i++)pre[i]=pre[i-1]*(k-i); for(int i=n;i>=1;i--)suf[i]=suf[i+1]*(k-i); for(int i=2;i<=n;i++)fac=fac*i; for(int i=1;i<=n;i++){ top*=pre[i-1]*suf[i+1], bot*=fac[i-1]*fac[n-i]; if((n-i)&1) ret-=y[i]*top/bot; else ret+=y[i]*top/bot; } return ret; } ``` - ## 例题[【Codeforces Round #492 F. Cowmpany Cowmpensation】](https://codeforces.com/contest/995/problem/F) $\ \ \ \ \ \ \ $题目大意: $\ \ \ \ \ \ \ $给你一棵 $n$ 个节点的树,以 $1$ 为根节点,现在让你给每个节点分配一个权值$∈[1,D]$,使得每个节点的权值不超过他的父亲节点( $1$ 号节点除外),问一共有多少种分配方式。$(1≤n≤3000, 1≤D≤10^9)$。 $\ \ \ \ \ \ \ $显然直观的有一个二维树形dp的做法,一维表示当前节点,一维表示这个节点的取值情况,来描述方案数: $$f_{i,j}=\prod_{s\in Son_i}f_{s,j}+f_{i,j-1}$$ $\ \ \ \ \ \ \ $答案就是$f_{1,D}$了。 $\ \ \ \ \ \ \ $复杂度为 $O(nD)$ 的,过不了,但是我们观察dp式子,可以发现……这个家伙是个多项式吧?继续大胆猜测,这个是个 $n$ 维多项式。 $\ \ \ \ \ \ \ $所以我们用dp来预处理出 $f_{1,i}\ ,\ {i\in[1,n+1]}$,然后使用拉格朗日差值法,便可以求出 $f_{1,D}$ 了,复杂度是 $O(n^2)$ 的。 $\ \ \ \ \ \ \ $代码如下,~~取模真的取死人啊~~ ```cpp #include<algorithm> #include<iostream> #include<cstdlib> #include<cstring> #include<cctype> #include<cstdio> #include<vector> #include<string> #include<queue> #include<stack> #include<cmath> #include<map> #include<set> using namespace std; const int inf=0x7fffffff; const double eps=1e-10; const double pi=acos(-1.0); //char buf[1<<15],*S=buf,*T=buf; //char getch(){return S==T&&(T=(S=buf)+fread(buf,1,1<<15,stdin),S==T)?0:*S++;} inline int read(){ int x=0,f=1;char ch;ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-') f=0;ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch&15);ch=getchar();} if(f)return x;else return -x; } const int N=3030; const long long mod=1e9+7; long long power(long long a,long long b){ long long ans=1ll; for(;b;b>>=1,a=a*a%mod)if(b&1ll)ans=ans*a%mod; return ans; } int n; long long D; vector<int> G[N]; long long f[N][N]; void dfs(int u){ for(int i=1;i<=n+1;i++)f[u][i]=1; for(int i=0;i<G[u].size();i++){ dfs(G[u][i]); for(int j=1;j<=n+1;j++) f[u][j]=(f[u][j]*f[G[u][i]][j])%mod; } for(int x=1;x<=n+1;x++)f[u][x]=(f[u][x]+f[u][x-1])%mod; } long long pre[N],suf[N],fac[N]; long long Lagrange(long long *y,long long n,long long k){ long long top,bot,ret=0,Fac; pre[0]=suf[n+1]=Fac=1ll; for(long long i=1;i<=n;i++)pre[i]=(pre[i-1]*(k-i+mod)%mod)%mod; for(long long i=n;i>=1;i--)suf[i]=(suf[i+1]*(k-i+mod)%mod)%mod; for(long long i=2;i<=n;i++)Fac=Fac*i%mod; fac[n]=power(Fac,mod-2); for(int i=n-1;i>=0;i--)fac[i]=fac[i+1]*1ll*(i+1)%mod; for(int i=1;i<=n;i++){ top=(pre[i-1]*suf[i+1]%mod)%mod, bot=(fac[i-1]*fac[n-i]%mod)%mod; if((n-i)&1)ret=(ret-(y[i]*top%mod*bot%mod)+mod)%mod; else ret=(ret+(y[i]*top%mod*bot%mod))%mod; } return (ret+mod)%mod; } int main() { n=read();scanf("%I64d",&D);D%=mod; for(int i=2,a;i<=n;i++) a=read(),G[a].push_back(i); dfs(1); if(D<=1ll*n+1ll)printf("%I64d\n",f[1][D]); else printf("%I64d\n",Lagrange(f[1],n+1,D)); return 0; } ```