Pytorch 梯度下降算法【3/9】小批量梯度下降(Mini-Batch Gradient Descent)
【摘要】 在 PyTorch 中,小批量梯度下降法(Mini-Batch Gradient Descent)是梯度下降算法的一种变体。与批量梯度下降法(BGD)使用整个训练集的梯度进行参数更新不同,Mini-Batch Gradient Descent 在每次参数更新时使用一小批样本的梯度来更新模型参数。下面我将通过一个简单的线性回归问题来演示如何在 PyTorch 中使用小批量梯度下降法。首先,我们...
定义
在 PyTorch 中,小批量梯度下降法(Mini-Batch Gradient Descent)是梯度下降算法的一种变体。与批量梯度下降法(BGD)使用整个训练集的梯度进行参数更新不同,Mini-Batch Gradient Descent 在每次参数更新时使用一小批样本的梯度来更新模型参数。
模型示意图
由于mini-batch每次仅使用数据集中的一部分进行梯度下降,所以每次下降并不是严格按照朝最小方向下降,但是总体下降趋势是朝着最小方向,上图可以明显看出两者之间的区别。
案例实战
下面我将通过一个简单的线性回归问题来演示如何在 PyTorch 中使用小批量梯度下降法。
首先,我们需要导入 PyTorch 库并准备数据:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# 随机生成一些简单的线性数据
np.random.seed(42)
X = np.random.rand(100, 1) # 100个输入样本
y = 2 * X + 1 + 0.1 * np.random.randn(100, 1) # 添加随机噪声的目标输出
接下来,我们定义一个简单的线性模型,并使用小批量梯度下降法进行优化:
# 将数据转换为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)
# 定义线性模型
class LinearModel(nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
self.linear = nn.Linear(1, 1) # 输入特征数为1,输出特征数为1
def forward(self, x):
return self.linear(x)
# 创建模型实例和优化器
model = LinearModel()
criterion = nn.MSELoss() # 使用均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01) # 使用小批量梯度下降法
# 进行模型训练
num_epochs = 100
batch_size = 10
for epoch in range(num_epochs):
# 将数据切分成小批量
num_samples = X_tensor.shape[0]
indices = torch.randperm(num_samples)
for i in range(0, num_samples, batch_size):
batch_indices = indices[i:i+batch_size]
batch_X = X_tensor[batch_indices]
batch_y = y_tensor[batch_indices]
# 清零梯度
optimizer.zero_grad()
# 前向传播
y_pred = model(batch_X)
# 计算损失
loss = criterion(y_pred, batch_y)
# 反向传播
loss.backward()
# 更新参数
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
# 打印训练后的模型参数
print("训练后的模型参数:", model.state_dict())
在上述代码中,我们使用了 torch.optim.SGD
创建了一个小批量梯度下降法的优化器,并将其应用于训练过程中。每一轮迭代中,我们将数据随机切分成小批量,然后使用这些小批量的样本来计算梯度和更新模型参数。Mini-Batch Gradient Descent 综合了随机梯度下降(SGD)和批量梯度下降(BGD)的优点,既减少了计算成本,又保持了相对稳定和较快的训练过程。
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)