pytorch resnet50加载训练好的模型pridect
时间: 2023-06-17 08:04:11 浏览: 151
以下是使用 PyTorch 加载已经训练好的 ResNet50 模型进行预测的示例代码:
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 加载预训练的 ResNet50 模型
model = models.resnet50(pretrained=True)
model.eval()
# 定义预处理函数
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# 读取待预测的图片并进行预处理
image = Image.open("test.jpg")
image_tensor = preprocess(image)
image_tensor = image_tensor.unsqueeze(0)
# 使用模型进行预测
with torch.no_grad():
output = model(image_tensor)
_, predicted = torch.max(output, 1)
# 打印预测结果
print("Predicted class:", predicted.item())
```
上述代码中,首先使用 `models.resnet50(pretrained=True)` 加载预训练的 ResNet50 模型,并将其设置为评估模式(`model.eval()`)。接着,定义了一个预处理函数 `preprocess`,用于将待预测的图片进行预处理。在读取待预测的图片后,将其传入预处理函数,得到预处理后的图片张量,并使用 `unsqueeze(0)` 将其扩展为一个 batch。最后,使用 `model` 进行预测,并将预测结果的类别打印出来。
阅读全文