Confusion matrix code for 5-class classification with ResNet34, saving the confusion matrix as an image
Date: 2023-11-09 09:06:02 Views: 171
The following is reference code; adapt it to your own dataset and model:
```python
import torch
from torch import nn
from torchvision import models
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
# Load the dataset (preprocessing should match what was used during training)
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
test_dataset = ImageFolder('test/', transform=data_transforms)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
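# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The checkpoint below overwrites every weight, so skip the ImageNet download
model = models.resnet34(weights=None)  # on torchvision < 0.13 use pretrained=False
model.fc = nn.Linear(model.fc.in_features, 5)  # 5 output classes
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
model.load_state_dict(torch.load('resnet34.pth', map_location=device))
model.to(device)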
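# Run predictions and compute the confusion matrix
y_true = []
y_pred = []
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)  # index of the highest logit per sample
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
# Pass labels explicitly so the matrix covers every class,
# even if one never appears in y_true or y_pred
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(test_dataset.classes))))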
# Visualize the confusion matrix
plt.figure(figsize=(8, 6))
plt.imshow(cm, cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(len(test_dataset.classes))
plt.xticks(tick_marks, test_dataset.classes, rotation=45)
plt.yticks(tick_marks, test_dataset.classes)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
# Write the count in each cell (rows are true labels, columns are predictions)
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, str(cm[i, j]), ha='center', va='center')
plt.tight_layout()  # keep the rotated tick labels inside the saved image
plt.savefig('confusion_matrix.png')
```
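In the code above, the `test/` folder is the path to the test set, and `resnet34.pth` is the path to the weights file of the trained ResNet34 model. The confusion matrix is saved as `confusion_matrix.png`.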
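To make clear what the matrix holds, here is a minimal NumPy sketch of the tally that `confusion_matrix` performs, plus a row-normalized version that is often easier to read when the classes are imbalanced. The helper names `tally_confusion_matrix` and `row_normalize` and the label lists are made up for illustration; they are not part of scikit-learn or of the code above. Class 4 is deliberately absent from the sample labels to show why the matrix size should be fixed to the number of classes.

```python
import numpy as np

def tally_confusion_matrix(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs into a num_classes x num_classes matrix."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (row, col) pair repeats
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

def row_normalize(cm):
    """Divide each row by its total: entry (i, j) becomes the fraction of
    class-i samples that were predicted as class j."""
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows with no samples stay all-zero instead of producing NaN
    return np.divide(cm, row_sums,
                     out=np.zeros(cm.shape, dtype=float),
                     where=row_sums != 0)

# Made-up labels for a 5-class problem, purely for illustration
y_true = [0, 0, 1, 1, 2, 2, 3, 3, 3, 0]
y_pred = [0, 1, 1, 1, 2, 0, 3, 3, 2, 0]

cm = tally_confusion_matrix(y_true, y_pred, num_classes=5)
cm_normalized = row_normalize(cm)

print(cm)
print(cm_normalized.round(2))
```

The diagonal of `cm_normalized` is the per-class recall. If you prefer to plot fractions instead of raw counts, one option is to pass the normalized matrix to `plt.imshow` and format each cell with something like `f'{value:.2f}'`.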