pytorch混淆矩阵
时间: 2023-08-06 12:07:27 浏览: 232
混淆矩阵及其可视化-精通开关电源设计 第2版 [(美)马尼克塔拉 著] 2015年 中文版
PyTorch中的混淆矩阵是用于评估分类模型性能的常见工具。它是一个正方形的矩阵,其中行表示实际类别,列表示预测类别。混淆矩阵的对角线上的元素表示正确分类的样本数量,而其他元素表示错误分类的样本数量。
在PyTorch中,可以使用sklearn库中的`confusion_matrix`函数来计算混淆矩阵。以下是一个示例代码:
```python
import torch
from sklearn.metrics import confusion_matrix
# 假设有真实标签和预测标签
true_labels = [0, 1, 2, 1, 0]
predicted_labels = [0, 2, 2, 1, 0]
# 将标签转换为Tensor
true_tensor = torch.tensor(true_labels)
predicted_tensor = torch.tensor(predicted_labels)
# 计算混淆矩阵
confusion = confusion_matrix(true_tensor, predicted_tensor)
print(confusion)
```
运行此代码将输出混淆矩阵:
```
[[2 0 0]
[0 2 0]
[0 1 0]]
```
这个混淆矩阵表示了3个类别的分类结果。对角线上的元素2表示正确分类的数量,其他元素表示错误分类的数量。
阅读全文