用pytorch写出adam算法
时间: 2024-10-16 14:09:54 浏览: 23
pytorch实现的手写数字识别.zip
Adam是一种常用的优化算法,它结合了动量(momentum)和自适应学习率的优点。在PyTorch中,我们可以使用内置的`optim.Adam`模块来实现这个算法。以下是一个简单的例子:
```python
import torch
from torch import nn
from torch.optim import Adam
# 定义模型
model = nn.Linear(input_dim, output_dim) # 假设input_dim和output_dim分别为输入特征数和输出节点数
# 初始化权重张量
parameters = model.parameters()
# 创建Adam优化器,设定初始学习率和其他参数(如动量因子、指数衰减等)
optimizer = Adam(parameters, lr=0.001, betas=(0.9, 0.999), eps=1e-8)
# 假设我们有一个loss函数
def compute_loss(output, target):
return nn.functional.mse_loss(output, target)
# 训练循环
for epoch in range(num_epochs):
for inputs, targets in train_dataloader:
optimizer.zero_grad() # 梯度清零
outputs = model(inputs)
loss = compute_loss(outputs, targets)
loss.backward() # 反向传播求梯度
optimizer.step() # 更新参数
阅读全文