torch.no_grad()
torch.no_grad() is a context manager provided by PyTorch that disables gradient tracking for every operation executed inside its scope. When training a deep learning model, we usually need gradients for the model's parameters so that an optimization algorithm can update them. However, sometimes we want to run operations without building a computation graph at all, for example when using the model for inference or evaluation. Wrapping such code in torch.no_grad() skips the gradient bookkeeping, which can significantly reduce memory usage and speed up computation.
Here is an example of how to use torch.no_grad():
```python
import torch

# create a tensor with requires_grad=True
x = torch.tensor([3., 4.], requires_grad=True)

# compute a scalar function of x and backpropagate to populate x.grad
y = x.sum()
y.backward()

# update the tensor using an optimizer
optimizer = torch.optim.SGD([x], lr=0.1)
optimizer.step()

# use the tensor without gradient tracking
with torch.no_grad():
    z = x * 2
    print(z)
```
In this example, we create a tensor `x` with `requires_grad=True`, which means operations on it are tracked so that gradients can be computed. We then call `y.backward()` to compute the gradient of `y = x.sum()` with respect to `x`, which is stored in `x.grad`. We also create an SGD optimizer and update `x` with `optimizer.step()`. Finally, we use `x` inside the `with torch.no_grad():` block: the multiplication `z = x * 2` is not recorded in the autograd graph, so `z.requires_grad` is `False` and no gradients will be computed for it.
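As mentioned above, the most common use of torch.no_grad() is model inference or evaluation. The sketch below illustrates that pattern; the small `nn.Linear` model and the random input batch are hypothetical stand-ins for a real network and dataset:

```python
import torch
import torch.nn as nn

# hypothetical model and dummy input, only for illustration
model = nn.Linear(4, 2)
inputs = torch.randn(8, 4)

model.eval()  # switch layers such as dropout/batchnorm to evaluation behavior
with torch.no_grad():  # no computation graph is built inside this block
    outputs = model(inputs)

print(outputs.requires_grad)  # False: the outputs are detached from autograd
```

Note that `model.eval()` only changes the behavior of layers such as dropout and batch normalization; it does not disable gradient tracking, which is why it is usually combined with `torch.no_grad()` during evaluation.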