resnet50图像分类代码
时间: 2023-12-31 22:24:00 浏览: 121
ResNet图像分类代码
以下是使用ResNet50进行图像分类的代码示例:
```python
import torch
import torchvision
from torchvision import transforms
# 加载预训练的ResNet50模型
model = torchvision.models.resnet50(pretrained=True).eval().cuda()
# 定义图像预处理的转换操作
tf = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载并预处理图像
image = Image.open('image.jpg')
image = tf(image).unsqueeze(0).cuda()
# 使用ResNet50进行图像分类
output = model(image)
# 获取预测结果
_, predicted_idx = torch.max(output, 1)
predicted_label = predicted_idx.item()
# 打印预测结果
print("Predicted label:", predicted_label)
```
请注意,上述代码中的`image.jpg`是待分类的图像文件路径,你需要将其替换为你自己的图像文件路径。
阅读全文