深度学习基础:2.最小二乘法

举报
zstar 发表于 2022/08/06 00:50:29 2022/08/06
【摘要】 最小二乘法代数表示方法 假设多元线性方程有如下形式: f ( ...

最小二乘法代数表示方法

假设多元线性方程有如下形式:
f ( x ) = w 1 x 1 + w 2 x 2 + . . . + w d x d + b f(x) = w_1x_1+w_2x_2+...+w_dx_d+b f(x)=w1x1+w2x2+...+wdxd+b
w = ( w 1 , w 2 , . . . w d ) w = (w_1,w_2,...w_d) w=(w1,w2,...wd) x = ( x 1 , x 2 , . . . x d ) x = (x_1,x_2,...x_d) x=(x1,x2,...xd),则上式可写为
f ( x ) = w T x + b f(x) = w^Tx+b f(x)=wTx+b
多元线性回归的最小二乘法的代数法表示较为复杂,此处先考虑简单线性回归的最小二乘法表示形式。在简单线性回归中,w只包含一个分量,x也只包含一个分量,我们令此时的 x i x_i xi就是对应的自变量的取值,此时求解过程如下

优化目标可写为
S S E = ∑ i = 1 m ( f ( x i ) − y i ) 2 = E ( w , b ) SSE = \sum^m_{i=1}(f(x_i)-y_i)^2 = E_(w,b) SSE=i=1m(f(xi)yi)2=E(w,b)
通过偏导为0求得最终结果的最小二乘法求解过程为:
KaTeX parse error: No such environment: align at position 9: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…

KaTeX parse error: No such environment: align at position 9: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…

求得:
w = ∑ i = 1 m y i ( x i − x ˉ ) ∑ i = 1 m x i 2 − 1 m ( ∑ i = 1 m x i ) 2 w = \frac{\sum^m_{i=1}y_i(x_i-\bar{x}) }{\sum^m_{i=1}x^2_i-\frac{1}{m}(\sum^m_{i=1}x_i)^2 } w=i=1mxi2m1(i=1mxi)2i=1myi(xixˉ)

b = 1 m ∑ i = 1 m ( y i − w x i ) b = \frac{1}{m}\sum^m_{i=1}(y_i-wx_i) b=m1i=1m(yiwxi)

#最小二乘法的矩阵表示形式

设多元线性回归方程为:
f ( x ) = w 1 x 1 + w 2 x 2 + . . . + w d x d + b f(x) = w_1x_1+w_2x_2+...+w_dx_d+b f(x)=w1x1+w2x2+...+wdxd+b

w ^ = ( w 1 , w 2 , . . . , w d , b ) \hat w = (w_1,w_2,...,w_d,b) w^=(w1,w2,...,wd,b)

x ^ = ( x 1 , x 2 , . . . , x d , 1 ) \hat x = (x_1,x_2,...,x_d,1) x^=(x1,x2,...,xd,1)


f ( x ) = w ^ ∗ x ^ T f(x) = \hat w * \hat x^T f(x)=w^x^T
有多个y值,则所有x值可以用矩阵X进行表示:
X = [ x 11 x 12 . . . x 1 d 1 x 21 x 22 . . . x 2 d 1 . . . . . . . . . . . . 1 x m 1 x m 2 . . . x m d 1 ] X = \left [ x 11 x 12 . . . x 1 d 1 x 21 x 22 . . . x 2 d 1 . . . . . . . . . . . . 1 x m 1 x m 2 . . . x m d 1 \right] X=x11x21...xm1x12x22...xm2............x1dx2d...xmd1111
y也可用m行1列的矩阵表示:
y = [ y 1 y 2 . . . y m ] y = \left [ y 1 y 2 . . . y m \right] y=y1y2...ym
此时,SSE表示为:
S S E = ∣ ∣ y − X w ^ T ∣ ∣ 2 2 = ( y − X w ^ T ) T ( y − X w ^ T ) = E ( w ^ ) SSE = ||y - X\hat w^T||_2^2 = (y - X\hat w^T)^T(y - X\hat w^T) = E(\hat w) SSE=yXw^T22=(yXw^T)T(yXw^T)=E(w^)
根据最小二乘法的求解过程,令 E ( w ^ ) E(\hat w) E(w^) w ^ \hat w w^求导方程取值为0,有
KaTeX parse error: No such environment: equation at position 879: …数,有如下规则: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ \\\frac{\parti…
简单记忆矩阵的求导,自身转置对自身求导=其它项的转置;

其中:
∂ ∣ ∣ y − X w ^ T ∣ ∣ 2 2 \partial{||\boldsymbol{y} - \boldsymbol{X\hat w^T}||_2}^2 yXw^T22
下方的2表示L2范式,即里面的内容相乘之后开根号,右上平方之后,消除根号。

进一步可得
X T X w ^ T = X T y X^TX\hat w^T = X^Ty XTXw^T=XTy
要使得此式有解,等价于 X T X X^TX XTX(也被称为矩阵的交叉乘积crossprod存在逆矩阵,若存在,则可解出)
w ^ T = ( X T X ) − 1 X T y \hat w ^T = (X^TX)^{-1}X^Ty w^T=(XTX)1XTy

最小二乘法的编程实现

例子:

X = A = [ 1 1 3 1 ] X = A = \left [ 1 1 3 1 \right] X=A=[1311]

y = B = [ 2 4 ] y = B = \left [ 2 4 \right] y=B=[24]

w ^ T = X T = [ a b ] \hat w ^T = X^T = \left [ a b \right] w^T=XT=[ab]

手动实现

X = A
X

  
 
  • 1
  • 2
tensor([[1., 1.],
        [3., 1.]])

  
 
  • 1
  • 2
y = B
y

  
 
  • 1
  • 2
tensor([[2.],
        [4.]])

  
 
  • 1
  • 2
X.t()

  
 
  • 1
tensor([[1., 3.],
        [1., 1.]])

  
 
  • 1
  • 2
w = torch.mm(torch.mm(torch.inverse(torch.mm(X.t(),X)),X.t()),y)

  
 
  • 1

这里直接套用上面推导出来的公式

w

  
 
  • 1
tensor([[1.0000],
        [1.0000]])

  
 
  • 1
  • 2

调用函数求解

torch.lstsq(y, X)

  
 
  • 1
torch.return_types.lstsq(
solution=tensor([[1.0000],
        [1.0000]]),
QR=tensor([[-3.1623, -1.2649],
        [ 0.7208, -0.6325]]))

  
 
  • 1
  • 2
  • 3
  • 4
  • 5

对于lstsq函数来说,第一个参数是因变量张量,第二个参数是自变量张量,并且同时返回结果还包括QR矩阵分解的结果。

补充知识点:范数的计算

求解L2范数:

# 默认情况,求解L2范数,个元素的平方和开平方
torch.linalg.norm(t)

  
 
  • 1
  • 2

求解L1范数:

# 输入参数,求解L1范数,个元素的绝对值之和
torch.linalg.norm(t, 1)

  
 
  • 1
  • 2

总结

最小二乘法计算快速,但条件苛刻,需满足X存在逆矩阵。

文章来源: zstar.blog.csdn.net,作者:zstar-_,版权归原作者所有,如需转载,请联系作者。

原文链接:zstar.blog.csdn.net/article/details/120610234

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。