pytorch如何使用已经训练好的resnet模型实现图片识别
时间: 2023-09-17 09:08:41 浏览: 97
Pytorch使用预训练模型进行图像分类
5星 · 资源好评率100%
使用已经训练好的resnet模型实现图片识别需要进行以下步骤:
1. 导入相关库和模型
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
# 加载resnet模型
model = models.resnet18(pretrained=True)
```
2. 预处理图片数据
```python
# 定义图片预处理方式
transform = transforms.Compose([
transforms.Resize(256), # 将图片大小调整为256×256
transforms.CenterCrop(224), # 从中心裁剪出224×224大小的图片
transforms.ToTensor(), # 将图片转换为张量
transforms.Normalize(
mean=[0.485, 0.456, 0.406], # 标准化图片
std=[0.229, 0.224, 0.225])
])
# 载入测试图片并进行预处理
img = Image.open('test.jpg')
img = transform(img).unsqueeze(0) # 增加一维作为batch_size
```
3. 使用模型进行图片识别
```python
# 将模型设为评估模式
model.eval()
# 将图片张量输入模型,得到预测结果
with torch.no_grad():
output = model(img)
# 将预测结果转换为概率值,并得到最大概率对应的标签
prob = torch.nn.functional.softmax(output, dim=1)[0]
pred_label_idx = torch.argmax(prob).item()
```
4. 查看预测结果
```python
# 载入标签名
with open('imagenet_classes.txt') as f:
classes = [line.strip() for line in f.readlines()]
# 打印预测结果
print('预测结果:', classes[pred_label_idx])
```
阅读全文