反向传播代码pytorch
时间: 2023-10-23 20:09:32 浏览: 116
好的,以下是一个简单的反向传播代码示例:
```python
import torch
# 定义输入和目标输出
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
y_true = torch.tensor([2.0, 4.0, 6.0, 8.0])
# 定义模型
w = torch.tensor([0.0], requires_grad=True)
b = torch.tensor([0.0], requires_grad=True)
# 定义损失函数
loss_fn = torch.nn.MSELoss()
# 定义优化器
optimizer = torch.optim.SGD([w, b], lr=0.01)
# 训练模型
for epoch in range(100):
# 前向传播
y_pred = w * x + b
# 计算损失
loss = loss_fn(y_pred, y_true)
# 反向传播
loss.backward()
# 更新参数
optimizer.step()
# 清空梯度
optimizer.zero_grad()
# 打印损失
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 100, loss.item()))
# 输出最终的参数值
print('w:', w.item())
print('b:', b.item())
```
阅读全文