torch.no_grad()
Date: 2024-05-04 19:15:10
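torch.no_grad() is a context manager that disables gradient calculation. It is useful during inference, when we do not need gradients because we will not update the model's parameters. It also reduces memory usage, because autograd no longer records operations or keeps the intermediate values it would otherwise save for the backward pass. Inside a no_grad() block, the result of every computation has requires_grad set to False, even when its inputs have requires_grad=True. Factory functions called with an explicit requires_grad=True argument, such as torch.zeros(3, requires_grad=True), are the exception: that argument is still honored. Because results computed inside the block are not attached to the autograd graph, gradients cannot flow back through them, even if they are used in later computations outside the block.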
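The rules above can be checked directly. This is a minimal sketch; the tensor shapes, the names `y`, `w` and `predict`, and the small `nn.Linear(4, 2)` model are illustrative assumptions, not part of the original text. It also shows that torch.no_grad() can be applied as a decorator to an inference function.

```python
import torch
import torch.nn as nn

x = torch.randn(3, requires_grad=True)

with torch.no_grad():
    # Result of an operation on x: not recorded by autograd.
    y = x * 2
    # Factory function with an explicit requires_grad=True: not affected by no_grad.
    w = torch.zeros(3, requires_grad=True)

print(y.requires_grad, y.grad_fn)  # False None
print(w.requires_grad)             # True

# Hypothetical inference helper; torch.no_grad() also works as a decorator.
model = nn.Linear(4, 2)

@torch.no_grad()
def predict(inputs):
    return model(inputs)

out = predict(torch.randn(5, 4))
print(out.requires_grad)           # False: the forward pass was not tracked
print(model.weight.requires_grad)  # True: the parameters themselves are unchanged
```

Note that no_grad() only affects what autograd records; it does not change the requires_grad flag of the model's parameters.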
Here is an example:
```python
import torch

x = torch.randn(3, 5, requires_grad=True)
y = torch.randn(3, 5)

with torch.no_grad():
    z = x + y  # the addition is not recorded by autograd

print(z.requires_grad)  # False

try:
    z.sum().backward()  # raises: z is not part of any autograd graph
except RuntimeError as err:
    print("RuntimeError:", err)
```
In the above example, we create two tensors, x and y. x has requires_grad set to True, so gradients with respect to x would normally be computed during backpropagation. However, we add x and y inside the no_grad() block, so autograd does not record the addition: z has requires_grad=False and no grad_fn. Finally, when we call backward() on z.sum(), PyTorch raises a RuntimeError, because z is not connected to any autograd graph and there is nothing to backpropagate through.
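For contrast, here is a minimal sketch of the same addition performed outside torch.no_grad(). It assumes the same shapes as the example above; the only change is that the addition is now recorded, so backward() succeeds and populates x.grad.

```python
import torch

x = torch.randn(3, 5, requires_grad=True)
y = torch.randn(3, 5)

# Same addition, but outside torch.no_grad(): autograd records it.
z = x + y
print(z.requires_grad)        # True
print(z.grad_fn is not None)  # True: z now has a node in the autograd graph

z.sum().backward()  # succeeds

# d(sum(x + y))/dx is 1 for every element of x.
print(torch.equal(x.grad, torch.ones(3, 5)))  # True
# y was created without requires_grad, so no gradient is stored for it.
print(y.grad)  # None
```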