with torch.no_grad():
`torch.no_grad()` is a context manager in PyTorch that disables gradient calculation. It is typically used during inference, when gradients are not needed, or for parts of a training script where we want to avoid tracking operations and thereby save memory.
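As a quick illustration (a minimal sketch using a toy tensor), operations performed inside the context are not tracked by autograd:
```
import torch

x = torch.ones(3, requires_grad=True)

# Outside the context, the result is tracked for autograd.
y = x * 2
print(y.requires_grad)  # True

# Inside torch.no_grad(), the same operation is not tracked.
with torch.no_grad():
    z = x * 2
print(z.requires_grad)  # False
```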
For example, when evaluating a trained model on a test dataset, we don't need gradients because the model parameters are not being updated. In this case, we can wrap our evaluation code in `torch.no_grad()` to disable gradient computation, which reduces memory usage and speeds up the forward pass.
Here's an example:
```
model.eval()  # switch layers such as dropout/batch norm to evaluation mode
test_loss = 0.0
correct = 0

with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        loss = criterion(output, target)
        test_loss += loss.item()
        # The prediction is the class with the highest score
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

# Average the loss over all test samples (assumes criterion uses reduction='sum')
test_loss /= len(test_loader.dataset)
test_accuracy = 100. * correct / len(test_loader.dataset)
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
    test_loss, correct, len(test_loader.dataset), test_accuracy))
```
In this code snippet, we evaluate a trained model on a test dataset. Wrapping the evaluation loop in `torch.no_grad()` disables gradient computation, so no memory is spent storing the intermediate activations that backpropagation would otherwise require.
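`torch.no_grad()` can also be applied as a decorator, which is convenient when an entire function should run without gradient tracking. Here is a minimal sketch (the `evaluate` function and its arguments are just for illustration):
```
@torch.no_grad()
def evaluate(model, data, target, criterion):
    # Everything inside this function runs with gradient tracking disabled
    output = model(data)
    return criterion(output, target).item()
```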