autograd.function实现relu
时间: 2024-02-05 16:03:20 浏览: 120
下面是使用`autograd.Function`实现ReLU的示例代码:
```python
import torch
class ReLUFunction(torch.autograd.Function):
"""
自定义ReLU函数,继承自autograd.Function
"""
@staticmethod
def forward(ctx, input):
"""
前向传播函数,保存输入张量并计算ReLU
"""
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def backward(ctx, grad_output):
"""
反向传播函数,计算梯度
"""
input, = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input
# 创建输入张量
x = torch.randn(2, 3, requires_grad=True)
# 使用自定义ReLU函数进行前向传播
relu = ReLUFunction.apply
y = relu(x)
# 计算梯度并输出
y.sum().backward()
print(x.grad)
```
这里我们自定义了一个`ReLUFunction`类,继承自`autograd.Function`,并实现了`forward`和`backward`方法,分别对应前向传播和反向传播。在前向传播中,我们使用`clamp`函数计算ReLU,并保存输入张量到`ctx`中,以便在反向传播中使用。在反向传播中,我们首先从`ctx`中取出保存的输入张量,然后根据梯度传播原理计算梯度。最后,我们使用`apply`方法调用自定义函数进行前向传播。
阅读全文