resnet18推理代码
时间: 2024-04-23 07:19:32 浏览: 112
ResNet代码
ResNet-18是一个经典的深度卷积神经网络模型,常用于图像分类任务。下面是一个简单的ResNet-18推理代码示例:
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 加载预训练的ResNet-18模型
model = models.resnet18(pretrained=True)
model.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载并预处理图像
image_path = 'path_to_your_image.jpg'
image = Image.open(image_path)
image = transform(image)
image = torch.unsqueeze(image, 0)
# 使用ResNet-18进行推理
with torch.no_grad():
output = model(image)
# 获取预测结果
_, predicted_idx = torch.max(output, 1)
predicted_label = predicted_idx.item()
# 打印预测结果
print(f"Predicted label: {predicted_label}")
```
请注意,上述代码中的`'path_to_your_image.jpg'`需要替换为你自己的图像路径。此外,还需要安装`torch`和`torchvision`库,并确保已经下载了ResNet-18的预训练权重。
阅读全文