Pytorch 梯度下降算法【4/9】动量梯度下降(Momentum Gradient Descent)

举报
林欣 发表于 2023/07/25 16:14:14 2023/07/25
【摘要】 在 PyTorch 中,动量梯度下降(Momentum Gradient Descent)是梯度下降算法的一种改进版。与传统的随机梯度下降(SGD)只考虑当前梯度方向不同,动量梯度下降考虑了历史梯度方向,类似于模拟物体滚下斜坡时的惯性效果,使得参数更新更加平滑和稳定。下面我将通过一个简单的线性回归问题来演示如何在 PyTorch 中使用动量梯度下降法。首先,我们需要导入 PyTorch 库并...

57194dcbce77c1c1eb0bf47e2fbec2d1_1690272690210861000.png

在 PyTorch 中,动量梯度下降(Momentum Gradient Descent)是梯度下降算法的一种改进版。与传统的随机梯度下降(SGD)只考虑当前梯度方向不同,动量梯度下降考虑了历史梯度方向,类似于模拟物体滚下斜坡时的惯性效果,使得参数更新更加平滑和稳定。下面我将通过一个简单的线性回归问题来演示如何在 PyTorch 中使用动量梯度下降法。

首先,我们需要导入 PyTorch 库并准备数据:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 随机生成一些简单的线性数据
np.random.seed(42)
X = np.random.rand(100, 1)  # 100个输入样本
y = 2 * X + 1 + 0.1 * np.random.randn(100, 1)  # 添加随机噪声的目标输出

接下来,我们定义一个简单的线性模型,并使用动量梯度下降法进行优化:

# 将数据转换为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)

# 定义线性模型
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = nn.Linear(1, 1)  # 输入特征数为1,输出特征数为1

    def forward(self, x):
        return self.linear(x)

# 创建模型实例和优化器
model = LinearModel()
criterion = nn.MSELoss()  # 使用均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # 使用动量梯度下降法

# 进行模型训练
num_epochs = 100

for epoch in range(num_epochs):
    # 清零梯度
    optimizer.zero_grad()
    # 前向传播
    y_pred = model(X_tensor)
    # 计算损失
    loss = criterion(y_pred, y_tensor)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 打印训练后的模型参数
print("训练后的模型参数:", model.state_dict())

在上述代码中,我们使用了 torch.optim.SGD 创建了一个动量梯度下降法的优化器,并将其应用于训练过程中。通过在创建优化器时指定 momentum 参数为0.9,我们添加了动量效果。在每次参数更新时,动量梯度下降法会考虑历史梯度方向,并使用动量项来调整参数更新方向,使得更新更加平滑和稳定。动量梯度下降法通常能够加速收敛,并减少参数更新的震荡。

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。