混淆矩阵pytorch实现
时间: 2024-04-22 20:21:12 浏览: 116
混淆矩阵是用于评估分类模型性能的一种常用工具,它可以展示模型在不同类别上的分类结果。在PyTorch中,可以通过以下步骤来实现混淆矩阵:
1. 导入必要的库:
```python
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
```
2. 定义模型和数据加载器:
```python
# 定义模型
model = YourModel()
# 加载测试数据
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
```
3. 运行测试集并获取预测结果:
```python
# 设置模型为评估模式
model.eval()
# 存储真实标签和预测标签
true_labels = []
pred_labels = []
# 遍历测试集并进行预测
with torch.no_grad():
for images, labels in test_loader:
# 将数据传入模型进行预测
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
# 将真实标签和预测标签添加到列表中
true_labels.extend(labels.numpy())
pred_labels.extend(predicted.numpy())
```
4. 计算混淆矩阵:
```python
# 将列表转换为numpy数组
true_labels = np.array(true_labels)
pred_labels = np.array(pred_labels)
# 计算混淆矩阵
cm = confusion_matrix(true_labels, pred_labels)
```
5. 可以根据需要对混淆矩阵进行可视化或进一步分析。
阅读全文