混淆矩阵pytorch
时间: 2023-08-30 21:09:44 浏览: 99
confusion_混淆矩阵、pytorch、模型_混淆矩阵pytorch_混淆矩阵_
5星 · 资源好评率100%
混淆矩阵(Confusion Matrix)是用来评价分类模型预测结果的一种常用方法,它可以将模型预测的结果与真实标签进行比较,统计出模型分类正确和错误的情况。在 PyTorch 中,可以使用 sklearn 库的 confusion_matrix 函数来创建混淆矩阵。
以下是一个示例代码,它使用了一个简单的分类模型来对 CIFAR-10 数据集进行训练,并生成混淆矩阵:
```python
import torch
import torchvision
import numpy as np
from sklearn.metrics import confusion_matrix
# 定义分类模型
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5)
self.pool = torch.nn.MaxPool2d(2, 2)
self.conv2 = torch.nn.Conv2d(6, 16, 5)
self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
self.fc2 = torch.nn.Linear(120, 84)
self.fc3 = torch.nn.Linear(84, 10)
def forward(self, x):
x = self.pool(torch.nn.functional.relu(self.conv1(x)))
x = self.pool(torch.nn.functional.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = torch.nn.functional.relu(self.fc1(x))
x = torch.nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return x
# 加载 CIFAR-10 数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=torchvision.transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
shuffle=False, num_workers=2)
# 加载模型
net = Net()
net.load_state_dict(torch.load('cifar_net.pth'))
# 测试模型
net.eval()
preds = []
labels = []
with torch.no_grad():
for data in testloader:
inputs, targets = data
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
preds.extend(predicted.numpy())
labels.extend(targets.numpy())
# 生成混淆矩阵
cm = confusion_matrix(labels, preds)
print(cm)
```
在上面的示例代码中,我们首先定义了一个简单的分类模型(Net),然后加载了 CIFAR-10 数据集和训练好的模型参数。接着,在测试集上使用模型进行预测,并将预测结果和真实标签存储在 preds 和 labels 列表中。最后,使用 confusion_matrix 函数生成混淆矩阵,并输出结果。
输出的混淆矩阵是一个 10x10 的矩阵,其中每行表示真实标签,每列表示预测标签。例如,第 i 行第 j 列表示真实标签为 i,预测标签为 j 的样本数量。对角线上的元素表示分类正确的样本数,非对角线上的元素表示分类错误的样本数。
阅读全文