resnet50 pridict pytorch
时间: 2023-06-17 13:04:14 浏览: 103
resnet pytorch代码
以下是在 PyTorch 中使用 ResNet50 进行预测的示例代码:
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 加载预训练的 ResNet50 模型
model = models.resnet50(pretrained=True)
model.eval()
# 对输入图像进行预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载图像并进行预测
img = Image.open('image.jpg')
img = transform(img)
img = img.unsqueeze(0)
output = model(img)
# 输出预测结果
_, predicted = torch.max(output.data, 1)
print('Predicted class:', predicted.item())
```
在上面的示例中,我们首先加载了预训练的 ResNet50 模型,并将其设置为评估模式(`model.eval()`)。然后,我们定义了一个变换(`transform`)来对输入图像进行预处理,包括缩放、中心裁剪和归一化操作。接下来,我们加载图像并对其进行预测。最后,我们使用 `torch.max()` 函数找到输出向量中最大值的索引,从而得到预测结果。
阅读全文