torch.nn.Module.backward()
时间: 2024-10-25 18:07:41 浏览: 21
`torch.nn.Module.backward()` 是PyTorch中的一个核心功能,用于反向传播(backpropagation)。当你通过模型执行前向传递(forward pass),计算损失(loss)之后,你可以调用 `.backward()` 方法来开始反向传播过程。这个方法会自动对模型的所有参数应用链式法则,计算每个参数对于整个损失函数的梯度。
下面是一个简单示例[^1]:
```python
# 假设我们有一个自定义的Model类
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
# 添加一些可学习的参数
def forward(self, input):
# 前向传播逻辑
output = ... # 模型计算结果
return output
# 初始化并创建实例
model = MyModule()
input = torch.randn(1, 3, 224, 224)
output = model(input)
# 计算损失
loss = F.cross_entropy(output, target) # 假设target是正确的类别标签
# 开始反向传播
loss.backward()
# 此时,model.parameters()中的每个参数都将包含其梯度
for param in model.parameters():
print(param.grad)
```
在这个过程中,`.grad`属性存储了参数相对于当前损失的梯度。注意,如果你有非可微操作(如ReLU激活函数在某些点上梯度为零的情况),`backward()`默认不会改变这些位置的梯度。若需特殊处理,可能需要自定义`Function`类或者使用`torch.autograd.Function`的`register_hook()`方法。
阅读全文