python中 mini-batch梯度下降的代码如何引用
时间: 2024-12-24 16:35:20 浏览: 6
在Python中,Mini-Batch梯度下降是一种优化算法,用于训练机器学习模型,特别是深度学习网络。它会在每次迭代时处理一小部分数据(mini-batch),而不是一次性使用所有数据。以下是一个简单的例子,展示如何在PyTorch库中实现Mini-Batch梯度下降:
```python
import torch
import torch.nn as nn
import torch.optim as optim
# 假设我们有一个线性回归模型
class LinearRegressionModel(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearRegressionModel, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x)
# 创建模型、输入张量和标签张量
model = LinearRegressionModel(input_dim=10, output_dim=1)
inputs = torch.randn(100, 10) # 100个样本,每个样本有10维特征
labels = torch.randn(100, 1) # 目标值
# 使用Adam优化器(包含Mini-Batch支持)
optimizer = optim.SGD(model.parameters(), lr=0.01, batch_size=32) # 学习率和批量大小
for epoch in range(num_epochs): # 设定总迭代次数
for i in range(0, len(inputs), optimizer.batch_size): # 循环处理mini-batches
batch_inputs = inputs[i:i+optimizer.batch_size]
batch_labels = labels[i:i+optimizer.batch_size]
# 前向传播计算损失
outputs = model(batch_inputs)
loss = nn.MSELoss()(outputs, batch_labels)
# 反向传播并更新权重
optimizer.zero_grad() # 清除上一步梯度
loss.backward()
optimizer.step() # 进行梯度下降
阅读全文