Q(a,b,c)=\sum_{i=1}^{n}\bigl(y_i - a - b x_i - c x_i^2\bigr)^2.
类似地,采用配方法:
先对 a 配方:固定 b,c,把 Q 写成 \sum\bigl[(y_i - b x_i - c x_i^2) - a\bigr]^2,得到
a = \frac{\sum (y_i - b x_i - c x_i^2)}{n} = \bar{y} - b\bar{x} - c\,\overline{x^2}, \tag{3}
其中 \overline{x^2} = \frac{1}{n}\sum x_i^2。
将 (3) 代回 Q,得到仅含 b,c 的二元二次函数。
通过再次对 b 和 c 配方(先固定 c 对 b 配方,再代入后对 c 配方),可以得到三个方程:
\begin{cases}
n a + (\sum x) b + (\sum x^2) c = \sum y \\
(\sum x) a + (\sum x^2) b + (\sum x^3) c = \sum xy \\
(\sum x^2) a + (\sum x^3) b + (\sum x^4) c = \sum x^2 y
\end{cases}
这样我们就成功仅用初中知识,求出了二次回归的正规方程组。
2.3 使用克拉默法则求解
上述三元一次方程组可用克拉默法则直接求解。系数行列式 D 及分子行列式 D_a,D_b,D_c 分别为:
D = \begin{vmatrix}
n & \sum x & \sum x^2 \\
\sum x & \sum x^2 & \sum x^3 \\
\sum x^2 & \sum x^3 & \sum x^4
\end{vmatrix},\quad
D_a = \begin{vmatrix}
\sum y & \sum x & \sum x^2 \\
\sum xy & \sum x^2 & \sum x^3 \\
\sum x^2 y & \sum x^3 & \sum x^4
\end{vmatrix},
D_b = \begin{vmatrix}
n & \sum y & \sum x^2 \\
\sum x & \sum xy & \sum x^3 \\
\sum x^2 & \sum x^2 y & \sum x^4
\end{vmatrix},\quad
D_c = \begin{vmatrix}
n & \sum x & \sum y \\
\sum x & \sum x^2 & \sum xy \\
\sum x^2 & \sum x^3 & \sum x^2 y
\end{vmatrix}.
则 a = D_a/D,\; b = D_b/D,\; c = D_c/D。代码中的 d,da,dB,dc 正是这些行列式按第一行展开后的代数表达式。
3. 平方根回归 y = a + b\sqrt{x}
令 u = \sqrt{x},对 u 与 y 进行线性回归即可(公式同线性回归)。
4. 决定系数 R^2
用于评价拟合优度,其计算公式为:
R^2 = 1 - \frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{n} (y_i - \bar{y})^2},
## 附上线性回归拟合代码
```cpp line-numbers
#include<bits/stdc++.h>
//Things will always be better, just let time play
using namespace std;
#define HACKILLER_FAST ios::sync_with_stdio(false),cin.tie(0)
//#define int long long
#define rint register int
#define db double
#define pii pair<int,int>
#define vc vector
#define pb push_back
#define ls x<<1
#define rs x<<1|1
struct node{
db xi,yi;
}a[100050];
int n;
db la,lb,lr2;
db qa,qb,qc,qr2;
db ea,eb,er2,ln[100050]; //指数函数还未推出公式
db sa,sb,sr2;
inline void solve_linear(){
db sx = 0,sy = 0,sx2 = 0,sxy = 0;
for(int i = 0;i < n;++i){
sx+=a[i].xi,sy+=a[i].yi;
sx2+=a[i].xi*a[i].xi,sxy+=a[i].xi*a[i].yi;
}
db xb = sx/n,yb = sy/n;
lb = (sxy-n*xb*yb)/(sx2-n*xb*xb);
la = yb-lb*xb;
db tot = 0,res = 0;
for(int i = 0;i < n;++i){
res+=(a[i].yi-(la+lb*a[i].xi))*(a[i].yi-(la+lb*a[i].xi));
tot+=(a[i].yi-yb)*(a[i].yi-yb);
}
lr2 = 1.0-res/tot;
}
inline void solve_quadratic(){
db sx = 0,sx2 = 0,sx3 = 0,sx4 = 0,sy = 0,sxy = 0,sx2y = 0;
for(int i = 0;i < n;++i){
sx+=a[i].xi,sx2+=a[i].xi*a[i].xi,sx3+=a[i].xi*a[i].xi*a[i].xi,sx4+=a[i].xi*a[i].xi*a[i].xi*a[i].xi;
sy+=a[i].yi,sxy+=a[i].xi*a[i].yi,sx2y+=a[i].xi*a[i].xi*a[i].yi;
}
db d = n*(sx2*sx4-sx3*sx3)-sx*(sx*sx4-sx3*sx2)+sx2*(sx*sx3-sx2*sx2);
db da = sy*(sx2*sx4-sx3*sx3)-sx*(sxy*sx4-sx3*sx2y)+sx2*(sxy*sx3-sx2*sx2y);
db dB = n*(sxy*sx4-sx3*sx2y)-sy*(sx*sx4-sx3*sx2)+sx2*(sx*sx2y-sxy*sx2);
db dc = n*(sx2*sx2y-sxy*sx3)-sx*(sx*sx2y-sxy*sx2)+sy*(sx*sx3-sx2*sx2);
qa = da/d,qb = dB/d,qc = dc/d;
db yb = sy/n,tot = 0,res = 0;
for(int i = 0;i < n;++i){
res+=(a[i].yi-(qa+qb*a[i].xi+qc*a[i].xi*a[i].xi))*(a[i].yi-(qa+qb*a[i].xi+qc*a[i].xi*a[i].xi));
tot+=(a[i].yi-yb)*(a[i].yi-yb);
}
qr2 = 1.0-res/tot;
}
inline void solve_sqrt(){
db sx = 0,sy = 0,sx2 = 0,sxy = 0;
for(int i = 0;i < n;++i){
sx+=sqrt(a[i].xi),sy+=a[i].yi;
sx2+=a[i].xi,(sxy+=sqrt(a[i].xi)*a[i].yi);
}
db xb = sx/n,yb = sy/n;
sb = (sxy-n*xb*yb)/(sx2-n*xb*xb);
sa = yb-sb*xb;
db tot = 0,res = 0;
for(int i = 0;i < n;++i){
res+=(a[i].yi-(sa+sb*sqrt(a[i].xi)))*(a[i].yi-(sa+sb*sqrt(a[i].xi)));
tot+=(a[i].yi-yb)*(a[i].yi-yb);
}
sr2 = 1.0-res/tot;
}
signed main(){
freopen("data.in","r",stdin);
// freopen("task.out","w",stdout);
HACKILLER_FAST;
cin>>n;
for(int i = 0;i < n;++i) cin>>a[i].xi>>a[i].yi;
solve_linear();
solve_quadratic();
solve_sqrt();
cout<<fixed<<setprecision(6);
cout<<"Linear model: y = "<<la;
if(lb >= 0) cout<<" + "<<lb<<" * x";
else cout<<" - "<<-lb<<" * x";
cout<<", R^2 = "<<lr2<<'\n';
cout<<"Quadratic model: y = "<<qa;
if(qb >= 0) cout<<" + "<<qb<<" * x";
else cout<<" - "<<-qb<<" * x";
if(qc >= 0) cout<<" + "<<qc<<" * x^2";
else cout<<" - "<<-qc<<" * x^2";
cout<<", R^2 = "<<qr2<<'\n';
cout<<"Square-root model: y = "<<sa;
if(sb >= 0) cout<<" + "<<sb<<" * sqrt(x)";
else cout<<" - "<<-sb<<" * sqrt(x)";
cout<<", R^2 = "<<sr2<<'\n';
return 0;
}
//No more; No less^_^