用resnet18训练完模型后,如何查看测试的分类结果,有代码示例吗
时间: 2024-03-17 20:39:56 浏览: 87
可以使用PyTorch中的torchvision包中的函数来查看测试的分类结果。具体步骤如下:
1. 加载训练好的模型和测试集数据。
```python
import torch
import torchvision
from torchvision import models, transforms, datasets
# 加载训练好的模型
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 2)
model.load_state_dict(torch.load('resnet18_model.pth'))
# 加载测试集数据
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
test_dataset = datasets.ImageFolder('test', transform=data_transforms)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)
```
2. 定义一个函数来测试模型并输出分类结果。
```python
def test_model(model, test_loader):
model.eval()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
correct_preds = 0
total_preds = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
correct_preds += torch.sum(preds == labels.data)
total_preds += len(labels)
accuracy = correct_preds.double() / total_preds
print('Accuracy on test set: {:.2%}'.format(accuracy))
```
3. 调用该函数来测试模型并输出分类结果。
```python
test_model(model, test_loader)
```
输出结果类似于:Accuracy on test set: 90.00%,表示测试集上的准确率为90%。
其中,“test”文件夹是存放测试集图片的文件夹,每个子文件夹代表一类图片,文件夹名即为该类别的标签名。
阅读全文