[吴恩达机器学习]3·正规方程解多元线性回归

吴恩达机器学习系列课程:https://www.bilibili.com/video/BV164411b7dx

正规方程

说白了,这就是用我们在微积分中学习的多元微分学知识直接解出答案。

对于代价函数: \[ J(\theta)=J(\theta_0, \theta_1,\cdots, \theta_n) \] 如果它是连续的,那么要求出它的最小值,只需要令各偏导为零,就能得到 \(\theta\) 的值: \[ \frac{\partial J}{\partial \theta_j}=0,\quad j=0,1,\cdots,n \] 或写作向量形式: \[ \frac{\partial J}{\partial \theta}=\vec 0 \] 下面我们就来对多元线性回归的代价函数解一解。


多元线性回归的代价函数为: \[ J(\theta)=\frac{1}{2m}\sum_{i=1}^m\left(\theta^Tx^{(i)}-y^{(i)}\right)^2 \] 于是其偏导函数为: \[ \frac{\partial J}{\partial \theta}=\frac{1}{m}\sum_{i=1}^m\left(\theta^Tx^{(i)}-y^{(i)}\right)x^{(i)} \] 要使之为零向量,只能是: \[ \theta^Tx^{(i)}=y^{(i)},\quad i=1,2,\cdots,m \] 恒成立。写作矩阵为: \[ X\theta=y \] 其中, \[ X=\begin{bmatrix}x_0^{(1)}&x_1^{(1)}&\cdots& x_n^{(1)}\\x_0^{(2)}&x_1^{(2)}&\cdots& x_n^{(2)}\\\vdots&\vdots&\ddots&\vdots\\x_0^{(m)}&x_1^{(m)}&\cdots& x_n^{(m)}\end{bmatrix}=\begin{bmatrix}{x^{(1)}}^T\\{x^{(2)}}^T\\\vdots\\{x^{(m)}}^T\\\end{bmatrix},\quad y=\begin{bmatrix}y^{(1)}\\y^{(2)}\\\vdots\\y^{(m)}\end{bmatrix} \] 两边同时乘以 \(X^T\),假设 \(X^TX\) 可逆,解得: \[ \theta=(X^TX)^{-1}X^Ty \] 这就是数学上多元线性回归方程的精确解。


这里,\(X^TX\) 是一个 \((n+1)\times(n+1)\) 的矩阵,因此直接计算 \(\theta\) 的复杂度是 \(O(n^3)\) 的,如果 \(n\) 不是很大,这是有效的,但是如果 \(n\) 达到了 \(10^4,10^5\) 或更高级别,就需要使用梯度下降了。

实现

仍然对第二篇中的多元线性回归数据进行求解。

代码很简洁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import numpy as np

def J(T, X, Y):
res = 0
for i in range(m):
res += (np.matmul(T.T, X[i:i+1, :].T) - Y[i:i+1, :]) ** 2
res /= 2 * m;
return res

data = np.genfromtxt("ex1data2.txt", delimiter = ',')
(m, n) = data.shape
X = np.column_stack((np.ones((m, 1)), data[:, :-1]))
Y = data[:, -1:]
T = np.matmul(np.matmul(np.linalg.inv(np.matmul(X.T, X)), X.T), Y)
print(T)
print(J(T, X, Y))

很快给出了结果:\(\theta=(89597.9095428,\,139.21067402,\,-8738.01911233)^T\).

不可逆情形

前一节的推导基于 \(X^TX\) 可逆的假设,如若不可逆,我们只需将代码中的 inv() 换成 pinv() 求出伪逆矩阵即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import numpy as np

def J(T, X, Y):
res = 0
for i in range(m):
res += (np.matmul(T.T, X[i:i+1, :].T) - Y[i:i+1, :]) ** 2
res /= 2 * m;
return res

data = np.genfromtxt("ex1data2.txt", delimiter = ',')
(m, n) = data.shape
X = np.column_stack((np.ones((m, 1)), data[:, :-1]))
Y = data[:, -1:]
T = np.matmul(np.matmul(np.linalg.pinv(np.matmul(X.T, X)), X.T), Y)
print(T)
print(J(T, X, Y))

[吴恩达机器学习]3·正规方程解多元线性回归
https://xyfjason.github.io/blog-main/2020/12/22/吴恩达机器学习-3·正规方程解多元线性回归/
作者
xyfJASON
发布于
2020年12月22日
许可协议