with torch.no_grad():
Posted: 2024-05-06 22:19:56
The `torch.no_grad()` context manager is used to turn off gradient computation during model inference, i.e., when we are evaluating a model on test data or making predictions. When inside the `torch.no_grad()` context, any operation that requires gradient computation will not be tracked by PyTorch's autograd engine. This can help reduce memory usage and speed up computation, as we don't need to store intermediate values for backpropagation.
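A minimal sketch of this behavior: the same operation on a tensor with `requires_grad=True` produces a graph-tracked result outside the context, but an untracked one inside it.

```python
import torch

x = torch.ones(3, requires_grad=True)

# Outside the context, autograd records the operation
y = x * 2
print(y.requires_grad)  # True

# Inside torch.no_grad(), the same operation is not recorded
with torch.no_grad():
    z = x * 2
print(z.requires_grad)  # False
```

Because `z` carries no graph, calling `z.sum().backward()` would raise an error, while `y.sum().backward()` would populate `x.grad`.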
Here's an example of how we might use `torch.no_grad()` during model inference:
```
model.eval()  # Set the model to evaluation mode (affects dropout, batch norm, etc.)
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        # Perform forward pass to get predictions
        outputs = model(inputs)
        # The class with the highest score is the prediction
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Test Accuracy: {:.2f}%'.format(100 * correct / total))
```
In this example, we use `torch.no_grad()` to compute the test set accuracy without building a computation graph for the forward passes. This speeds up evaluation and reduces memory usage, especially for a large test set, since no intermediate activations need to be kept for backpropagation.