A Brief Look at Linear Regression Fitting Using Only Junior High School Math


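Small talk

Thanks to the "brilliant" homework assigned by my school's math department, I ended up debugging code for 3 hours.
This article strictly follows the Luogu article guidelines and takes a brief look at linear regression fitting that needs nothing beyond junior high school math.
My math is fairly weak, so if you spot a mistake, corrections are welcome (private message or comment both work).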

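Main text

The code in this article fits three regression models. All of them rest on the least-squares principle: find the coefficients that minimize the residual sum of squares.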

1. Linear regression $y = a + bx$

Compute the means $\bar{x}=\frac{1}{n}\sum x_i$ and $\bar{y}=\frac{1}{n}\sum y_i$; then

$$b = \frac{\sum x_i y_i - n\bar{x}\bar{y}}{\sum x_i^2 - n\bar{x}^2},\quad a = \bar{y} - b\bar{x}.$$
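The two closed-form formulas above translate almost directly into code. The following is a minimal sketch, not the article's final program: the names `LinearFit` and `fit_linear` are made up for this illustration, and it assumes the $x_i$ are not all equal so that the denominator is nonzero.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper type for this sketch: intercept a and slope b of y = a + b*x.
struct LinearFit {
    double a;
    double b;
};

// Least-squares fit of y = a + b*x using the closed-form formulas.
// Assumes xs and ys have the same length and the x values are not all equal.
LinearFit fit_linear(const std::vector<double>& xs, const std::vector<double>& ys) {
    assert(xs.size() == ys.size() && xs.size() >= 2);
    const double n = static_cast<double>(xs.size());
    double sx = 0, sy = 0, sxx = 0, sxy = 0;
    for (std::size_t i = 0; i < xs.size(); ++i) {
        sx += xs[i];
        sy += ys[i];
        sxx += xs[i] * xs[i];
        sxy += xs[i] * ys[i];
    }
    const double xbar = sx / n, ybar = sy / n;
    const double denom = sxx - n * xbar * xbar;  // zero only if all x_i are equal
    assert(denom != 0.0);
    const double b = (sxy - n * xbar * ybar) / denom;
    return {ybar - b * xbar, b};
}
```

For the points $(1,3)$, $(2,5)$, $(3,7)$, which lie exactly on $y = 1 + 2x$, `fit_linear` returns $a = 1$ and $b = 2$.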

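2. Quadratic regression $y = a + bx + cx^2$

2.1 Deriving the normal equations by completing the square (using linear regression as the example)

For the linear model $y = a + bx$, the residual sum of squares is

$$Q(a,b)=\sum_{i=1}^{n}\bigl(y_i - a - b x_i\bigr)^2.$$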

Step 1: fix $b$ and complete the square in $a$
Treat $Q$ as a quadratic function of $a$:

$$Q = \sum_{i=1}^{n}\bigl[(y_i - b x_i) - a\bigr]^2 = \sum_{i=1}^{n}(y_i - b x_i)^2 - 2a\sum_{i=1}^{n}(y_i - b x_i) + n a^2.$$

This is a quadratic in $a$: $n a^2 - 2\bigl(\sum (y_i-b x_i)\bigr)a + \text{constant}$.
A quadratic function $n a^2 + p a + q$ has its vertex at $a = -\frac{p}{2n}$, and since $n > 0$ the vertex is the minimum. Here $p = -2\sum (y_i-b x_i)$, so

$$a = \frac{2\sum (y_i-b x_i)}{2n} = \frac{\sum (y_i-b x_i)}{n} = \bar{y} - b\bar{x}. \tag{1}$$

Step 2: substitute (1) back into $Q$ and complete the square in $b$
Substituting $a = \bar{y} - b\bar{x}$ into $Q$ gives

$$\begin{aligned} Q &= \sum \bigl[y_i - (\bar{y} - b\bar{x}) - b x_i\bigr]^2 \\ &= \sum \bigl[(y_i-\bar{y}) - b(x_i-\bar{x})\bigr]^2. \end{aligned}$$

Let $u_i = y_i-\bar{y}$ and $v_i = x_i-\bar{x}$; then

$$Q = \sum (u_i - b v_i)^2 = \sum (v_i^2 b^2 - 2 u_i v_i b + u_i^2) = \bigl(\sum v_i^2\bigr) b^2 - 2\bigl(\sum u_i v_i\bigr) b + \sum u_i^2.$$

This is a quadratic function of $b$ whose leading coefficient satisfies $\sum v_i^2 > 0$ (unless all $x_i$ are equal). Its vertex is at

$$b = \frac{2\sum u_i v_i}{2\sum v_i^2} = \frac{\sum u_i v_i}{\sum v_i^2} = \frac{\sum (x_i-\bar{x})(y_i-\bar{y})}{\sum (x_i-\bar{x})^2}. \tag{2}$$
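Formula (2) is the same slope as the one stated in section 1. Expanding the centered sums and using $\sum x_i = n\bar{x}$ and $\sum y_i = n\bar{y}$ makes the match explicit:

$$\begin{aligned} \sum (x_i-\bar{x})(y_i-\bar{y}) &= \sum x_i y_i - \bar{y}\sum x_i - \bar{x}\sum y_i + n\bar{x}\bar{y} = \sum x_i y_i - n\bar{x}\bar{y}, \\ \sum (x_i-\bar{x})^2 &= \sum x_i^2 - 2\bar{x}\sum x_i + n\bar{x}^2 = \sum x_i^2 - n\bar{x}^2. \end{aligned}$$

Dividing the first line by the second gives $b = \frac{\sum x_i y_i - n\bar{x}\bar{y}}{\sum x_i^2 - n\bar{x}^2}$.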

2.2 Extending to quadratic regression

For the quadratic model $y = a + b x + c x^2$, the residual sum of squares is

$$Q(a,b,c)=\sum_{i=1}^{n}\bigl(y_i - a - b x_i - c x_i^2\bigr)^2.$$

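Completing the square works the same way:

  1. First complete the square in $a$: fix $b$ and $c$, and write $Q$ as $\sum\bigl[(y_i - b x_i - c x_i^2) - a\bigr]^2$, which gives

    $$a = \frac{\sum (y_i - b x_i - c x_i^2)}{n} = \bar{y} - b\bar{x} - c\,\overline{x^2}, \tag{3}$$

    where $\overline{x^2} = \frac{1}{n}\sum x_i^2$.

  2. Substitute (3) back into $Q$ to obtain a quadratic function of the two variables $b$ and $c$ only.
    Completing the square again in $b$ and $c$ (fix $c$ and complete the square in $b$, then substitute and complete the square in $c$) gives two more conditions. Together with (3), they are equivalent to the following three equations:

    $$\begin{cases} n a + (\sum x) b + (\sum x^2) c = \sum y \\ (\sum x) a + (\sum x^2) b + (\sum x^3) c = \sum xy \\ (\sum x^2) a + (\sum x^3) b + (\sum x^4) c = \sum x^2 y \end{cases}$$

    So, using only junior high school math, we have obtained the normal equations for quadratic regression.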
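Step 2 can be written out a little further. This is only a sketch: the shorthand $w_i = x_i^2 - \overline{x^2}$ is introduced here for convenience, alongside $u_i = y_i - \bar{y}$ and $v_i = x_i - \bar{x}$. Substituting (3) into $Q$ gives

$$Q(b,c) = \sum (u_i - b v_i - c w_i)^2.$$

For fixed $c$ this is a quadratic in $b$, and for fixed $b$ it is a quadratic in $c$. At the minimum both vertex conditions hold at once:

$$\begin{cases} (\sum v_i^2)\, b + (\sum v_i w_i)\, c = \sum u_i v_i \\ (\sum v_i w_i)\, b + (\sum w_i^2)\, c = \sum u_i w_i \end{cases}$$

Expanding the centered sums, for example $\sum v_i w_i = \sum x_i^3 - n\bar{x}\,\overline{x^2}$ and $\sum u_i w_i = \sum x_i^2 y_i - n\,\overline{x^2}\,\bar{y}$, shows that these are the second and third normal equations with $a$ eliminated through (3), while (3) multiplied by $n$ is the first normal equation.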

2.3 Solving with Cramer's rule

The system of three linear equations above can be solved directly with Cramer's rule. The coefficient determinant $D$ and the numerator determinants $D_a, D_b, D_c$ are:

$$\begin{aligned} D &= \begin{vmatrix} n & \sum x & \sum x^2 \\ \sum x & \sum x^2 & \sum x^3 \\ \sum x^2 & \sum x^3 & \sum x^4 \end{vmatrix}, & D_a &= \begin{vmatrix} \sum y & \sum x & \sum x^2 \\ \sum xy & \sum x^2 & \sum x^3 \\ \sum x^2 y & \sum x^3 & \sum x^4 \end{vmatrix}, \\ D_b &= \begin{vmatrix} n & \sum y & \sum x^2 \\ \sum x & \sum xy & \sum x^3 \\ \sum x^2 & \sum x^2 y & \sum x^4 \end{vmatrix}, & D_c &= \begin{vmatrix} n & \sum x & \sum y \\ \sum x & \sum x^2 & \sum xy \\ \sum x^2 & \sum x^3 & \sum x^2 y \end{vmatrix}. \end{aligned}$$

Then $a = D_a/D,\; b = D_b/D,\; c = D_c/D$. The variables `d`, `da`, `dB`, and `dc` in the code are exactly these determinants expanded along the first row.
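Cramer's rule for this system can be sketched with a small 3×3 determinant helper instead of writing out each expansion by hand. The names `Mat3`, `det3`, and `fit_quadratic` are hypothetical and exist only in this illustration. The sketch assumes at least three distinct $x$ values, so that $D \neq 0$.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

using Mat3 = std::array<std::array<double, 3>, 3>;

// Determinant of a 3x3 matrix, expanded along the first row.
double det3(const Mat3& m) {
    return m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
         - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
         + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
}

// Fits y = a + b*x + c*x^2 by solving the normal equations with Cramer's rule.
// Returns {a, b, c}. Assumes at least three distinct x values (so D != 0).
std::array<double, 3> fit_quadratic(const std::vector<double>& xs,
                                    const std::vector<double>& ys) {
    assert(xs.size() == ys.size() && xs.size() >= 3);
    double s[5] = {0, 0, 0, 0, 0};  // s[k] = sum of x^k for k = 0..4 (s[0] = n)
    double t[3] = {0, 0, 0};        // t[k] = sum of x^k * y for k = 0..2
    for (std::size_t i = 0; i < xs.size(); ++i) {
        double p = 1.0;  // running power x^k
        for (int k = 0; k <= 4; ++k) {
            s[k] += p;
            if (k <= 2) t[k] += p * ys[i];
            p *= xs[i];
        }
    }
    const Mat3 A = {{{s[0], s[1], s[2]}, {s[1], s[2], s[3]}, {s[2], s[3], s[4]}}};
    const double D = det3(A);
    assert(D != 0.0);
    std::array<double, 3> coef{};
    for (int col = 0; col < 3; ++col) {
        Mat3 Ak = A;  // copy of A with column `col` replaced by the right-hand side
        for (int row = 0; row < 3; ++row) Ak[row][col] = t[row];
        coef[col] = det3(Ak) / D;
    }
    return coef;
}
```

With $x = 0, 1, 2, 3$ and $y = 1, 6, 17, 34$ (generated from $y = 1 + 2x + 3x^2$), the determinants are $D = 80$, $D_a = 80$, $D_b = 160$, $D_c = 240$, so the function recovers $a = 1$, $b = 2$, $c = 3$.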

3. Square-root regression $y = a + b\sqrt{x}$

Let $u = \sqrt{x}$, then run linear regression on $u$ and $y$ (the formulas are the same as for linear regression).
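The substitution can be sketched as follows. The function name `fit_sqrt` is made up for this example, and the sketch assumes every $x_i$ is non-negative and that the $x_i$ are not all equal.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Fits y = a + b*sqrt(x) by substituting u = sqrt(x) and applying the
// ordinary linear-regression formulas to the pairs (u, y). Returns {a, b}.
// Assumes every x is non-negative and the x values are not all equal.
std::pair<double, double> fit_sqrt(const std::vector<double>& xs,
                                   const std::vector<double>& ys) {
    assert(xs.size() == ys.size() && xs.size() >= 2);
    const double n = static_cast<double>(xs.size());
    double su = 0, sy = 0, suu = 0, suy = 0;
    for (std::size_t i = 0; i < xs.size(); ++i) {
        assert(xs[i] >= 0.0);
        const double u = std::sqrt(xs[i]);
        su += u;
        sy += ys[i];
        suu += xs[i];  // u*u equals x, so the root does not need to be squared
        suy += u * ys[i];
    }
    const double ubar = su / n, ybar = sy / n;
    const double denom = suu - n * ubar * ubar;  // zero only if all u_i are equal
    assert(denom != 0.0);
    const double b = (suy - n * ubar * ybar) / denom;
    return {ybar - b * ubar, b};
}
```

For $x = 1, 4, 9, 16$ and $y = 5, 8, 11, 14$ (generated from $y = 2 + 3\sqrt{x}$), the function returns $a = 2$ and $b = 3$.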

4. Coefficient of determination $R^2$

It measures goodness of fit and is computed as:

$$R^2 = 1 - \frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{n} (y_i - \bar{y})^2},$$

where $\hat{y}_i$ is the value the model predicts for $x_i$ and $\bar{y}$ is the mean of the $y_i$.

## Attached: the regression fitting code

```cpp line-numbers
#include<bits/stdc++.h>
//Things will always be better, just let time play
using namespace std;
#define HACKILLER_FAST ios::sync_with_stdio(false),cin.tie(0)
//#define int long long
#define rint register int
#define db double
#define pii pair<int,int>
#define vc vector
#define pb push_back
#define ls x<<1
#define rs x<<1|1
struct node{
    db xi,yi;
}a[100050];
int n;
db la,lb,lr2;           // linear model: a, b, R^2
db qa,qb,qc,qr2;        // quadratic model: a, b, c, R^2
db ea,eb,er2,ln[100050]; // formula for the exponential model not derived yet
db sa,sb,sr2;           // square-root model: a, b, R^2
inline void solve_linear(){
    db sx = 0,sy = 0,sx2 = 0,sxy = 0;
    for(int i = 0;i < n;++i){
        sx+=a[i].xi,sy+=a[i].yi;
        sx2+=a[i].xi*a[i].xi,sxy+=a[i].xi*a[i].yi;
    }
    db xb = sx/n,yb = sy/n;
    lb = (sxy-n*xb*yb)/(sx2-n*xb*xb);
    la = yb-lb*xb;
    db tot = 0,res = 0;
    for(int i = 0;i < n;++i){
        res+=(a[i].yi-(la+lb*a[i].xi))*(a[i].yi-(la+lb*a[i].xi));
        tot+=(a[i].yi-yb)*(a[i].yi-yb);
    }
    lr2 = 1.0-res/tot;
}
inline void solve_quadratic(){
    db sx = 0,sx2 = 0,sx3 = 0,sx4 = 0,sy = 0,sxy = 0,sx2y = 0;
    for(int i = 0;i < n;++i){
        sx+=a[i].xi,sx2+=a[i].xi*a[i].xi,sx3+=a[i].xi*a[i].xi*a[i].xi,sx4+=a[i].xi*a[i].xi*a[i].xi*a[i].xi;
        sy+=a[i].yi,sxy+=a[i].xi*a[i].yi,sx2y+=a[i].xi*a[i].xi*a[i].yi;
    }
    // Cramer's rule: each determinant is expanded along its first row
    db d = n*(sx2*sx4-sx3*sx3)-sx*(sx*sx4-sx3*sx2)+sx2*(sx*sx3-sx2*sx2);
    db da = sy*(sx2*sx4-sx3*sx3)-sx*(sxy*sx4-sx3*sx2y)+sx2*(sxy*sx3-sx2*sx2y);
    db dB = n*(sxy*sx4-sx3*sx2y)-sy*(sx*sx4-sx3*sx2)+sx2*(sx*sx2y-sxy*sx2);
    db dc = n*(sx2*sx2y-sxy*sx3)-sx*(sx*sx2y-sxy*sx2)+sy*(sx*sx3-sx2*sx2);
    qa = da/d,qb = dB/d,qc = dc/d;
    db yb = sy/n,tot = 0,res = 0;
    for(int i = 0;i < n;++i){
        res+=(a[i].yi-(qa+qb*a[i].xi+qc*a[i].xi*a[i].xi))*(a[i].yi-(qa+qb*a[i].xi+qc*a[i].xi*a[i].xi));
        tot+=(a[i].yi-yb)*(a[i].yi-yb);
    }
    qr2 = 1.0-res/tot;
}
inline void solve_sqrt(){
    db sx = 0,sy = 0,sx2 = 0,sxy = 0;
    for(int i = 0;i < n;++i){
        sx+=sqrt(a[i].xi),sy+=a[i].yi;
        sx2+=a[i].xi; // (sqrt(x))^2 = x
        sxy+=sqrt(a[i].xi)*a[i].yi;
    }
    db xb = sx/n,yb = sy/n;
    sb = (sxy-n*xb*yb)/(sx2-n*xb*xb);
    sa = yb-sb*xb;
    db tot = 0,res = 0;
    for(int i = 0;i < n;++i){
        res+=(a[i].yi-(sa+sb*sqrt(a[i].xi)))*(a[i].yi-(sa+sb*sqrt(a[i].xi)));
        tot+=(a[i].yi-yb)*(a[i].yi-yb);
    }
    sr2 = 1.0-res/tot;
}
signed main(){
    freopen("data.in","r",stdin);
//  freopen("task.out","w",stdout);
    HACKILLER_FAST;
    cin>>n;
    for(int i = 0;i < n;++i) cin>>a[i].xi>>a[i].yi;
    solve_linear();
    solve_quadratic();
    solve_sqrt();
    cout<<fixed<<setprecision(6);
    cout<<"Linear model: y = "<<la;
    if(lb >= 0) cout<<" + "<<lb<<" * x";
    else cout<<" - "<<-lb<<" * x";
    cout<<", R^2 = "<<lr2<<'\n';
    cout<<"Quadratic model: y = "<<qa;
    if(qb >= 0) cout<<" + "<<qb<<" * x";
    else cout<<" - "<<-qb<<" * x";
    if(qc >= 0) cout<<" + "<<qc<<" * x^2";
    else cout<<" - "<<-qc<<" * x^2";
    cout<<", R^2 = "<<qr2<<'\n';
    cout<<"Square-root model: y = "<<sa;
    if(sb >= 0) cout<<" + "<<sb<<" * sqrt(x)";
    else cout<<" - "<<-sb<<" * sqrt(x)";
    cout<<", R^2 = "<<sr2<<'\n';
    return 0;
}
//No more; No less^_^
```
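The $R^2$ computation from section 4 is repeated in all three solvers, so it could also be pulled out into one helper. The sketch below shows one way to do that. The name `r_squared` is hypothetical, and the function assumes the $y_i$ are not all equal, since otherwise the denominator is zero and $R^2$ is undefined.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// R^2 = 1 - SS_res / SS_tot for observed values ys and model predictions preds.
// Assumes ys and preds have the same length and the ys are not all equal.
double r_squared(const std::vector<double>& ys, const std::vector<double>& preds) {
    assert(ys.size() == preds.size() && !ys.empty());
    double mean = 0;
    for (double y : ys) mean += y;
    mean /= static_cast<double>(ys.size());
    double ss_res = 0, ss_tot = 0;
    for (std::size_t i = 0; i < ys.size(); ++i) {
        ss_res += (ys[i] - preds[i]) * (ys[i] - preds[i]);  // residual sum of squares
        ss_tot += (ys[i] - mean) * (ys[i] - mean);          // total sum of squares
    }
    assert(ss_tot != 0.0);
    return 1.0 - ss_res / ss_tot;
}
```

A perfect fit gives $R^2 = 1$, while always predicting the mean $\bar{y}$ gives $R^2 = 0$.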