pytorch 图像分类 RESNET
时间: 2023-09-23 20:06:43 浏览: 183
你可以使用PyTorch中的预训练的ResNet模型来进行图像分类。下面是一个使用ResNet的示例代码:
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 加载预训练的ResNet模型
resnet = models.resnet50(pretrained=True)
resnet.eval()
# 预处理图像
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载图像
image_path = 'path/to/your/image.jpg'
image = Image.open(image_path)
# 预处理图像并添加批次维度
input_tensor = preprocess(image).unsqueeze(0)
# 使用ResNet进行推理
with torch.no_grad():
output = resnet(input_tensor)
# 加载类别标签
LABELS_URL = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'
LABELS = torch.hub.load_state_dict_from_url(LABELS_URL)
# 获取预测结果
_, predicted_idx = torch.max(output, 1)
predicted_label = LABELS[predicted_idx.item()]
print('预测结果:', predicted_label)
```
确保将`path/to/your/image.jpg`替换为你自己的图像路径。此外,你还需要安装Pillow和torchvision库,可以通过以下命令进行安装:
```
pip install Pillow torchvision
```
这段代码会加载预训练的ResNet-50模型,对输入图像进行预测,并输出预测结果。请注意,这是一个简单的示例,你可以根据需要进行修改和扩展。
阅读全文