torchvision怎么使用预训练模型?
时间: 2024-05-02 15:03:49 浏览: 57
torchvision resnet18 计算相似度
Torchvision提供了一些预先训练好的模型,可以直接使用。下面以ResNet-18为例,介绍如何使用预训练模型:
```python
import torch
import torchvision
# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
# 将模型设置为评估模式
model.eval()
# 随机生成一张图片
input_image = torch.rand(1, 3, 224, 224)
# 使用模型进行预测
output = model(input_image)
print(output)
```
首先,我们导入了PyTorch和Torchvision库。然后,我们使用`torchvision.models.resnet18`函数加载了预训练模型。这个函数返回一个模型实例,我们可以对它进行微调或使用它进行预测。
接下来,我们将模型设置为评估模式,这会关闭一些训练中使用的技巧,例如Dropout和Batch Normalization。这样可以确保模型在预测时的输出稳定。
然后,我们随机生成了一张图片作为输入。这里的图片是一个4D张量,包含了一个batch中的所有图片。我们只生成了一张图片,所以第一个维度的大小为1。
最后,我们使用模型进行预测,并打印了输出。注意,输出是一个2D张量,包含了每个类别的分数。可以使用`torch.nn.functional.softmax`函数将它们转换为概率。
阅读全文