我怎么样加载已训练好的模型进行预测
时间: 2023-08-10 21:09:17 浏览: 53
要加载已训练好的模型进行预测,你需要使用 PyTorch 提供的 `torch.load` 函数加载模型的参数和状态字典。下面是一个加载模型并进行预测的示例:
```python
import torch
from torchvision import models, transforms
from PIL import Image
# 定义图像变换
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 加载已训练好的模型
model = models.resnet50(pretrained=True)
model.eval()
# 加载图像并进行预处理
image = Image.open('image.jpg')
image = transform(image).unsqueeze(0)
# 加载模型参数和状态字典
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint['state_dict'])
# 进行预测
with torch.no_grad():
output = model(image)
# 处理预测结果
_, predicted_idx = torch.max(output, 1)
predicted_label = predicted_idx.item()
print(f"Predicted label: {predicted_label}")
```
在这个示例中,首先定义了图像的预处理变换,然后使用 `models.resnet50` 加载了一个预训练的 ResNet-50 模型,并将其设为评估模式。接下来,加载了要预测的图像,并进行了相同的预处理操作。然后,使用 `torch.load` 加载已训练模型的参数和状态字典,并使用 `load_state_dict` 将参数加载到模型中。最后,通过将图像传入模型进行预测,并处理预测结果。
请注意,这只是一个示例,你需要根据你的具体情况和模型进行相应的修改。确保模型的架构和预处理操作与训练时保持一致。
希望这个示例能帮到你!如有任何疑问,请随时提问。