基于pytorch实现猫狗分类测试代码
时间: 2023-09-02 21:12:06 浏览: 156
以下是使用PyTorch实现猫狗分类的测试代码:
```python
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 加载模型
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
model.load_state_dict(torch.load('model.pth'))
model.eval()
# 定义图像预处理函数
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载测试图片
img_path = 'test.jpg'
img = Image.open(img_path)
# 预处理图片并转换为模型输入
img_tensor = transform(img)
img_tensor = img_tensor.unsqueeze(0)
# 使用模型进行预测
with torch.no_grad():
output = model(img_tensor)
_, predicted = torch.max(output, 1)
# 打印预测结果
if predicted.item() == 0:
print('猫')
else:
print('狗')
```
其中,`model.pth`是训练好的模型文件,`test.jpg`是待测试的图片文件。在运行代码前需要先安装PyTorch和Pillow库。
阅读全文