混淆矩阵pytorch实现代码
时间: 2023-08-28 21:50:25 浏览: 222
混淆矩阵
以下是使用PyTorch实现混淆矩阵的代码示例:
```python
import torch
from sklearn.metrics import confusion_matrix
# 假设有两个类别
num_classes = 2
# 随机生成预测结果和真实标签
y_pred = torch.randint(num_classes, (100,))
y_true = torch.randint(num_classes, (100,))
# 将Tensor转换为NumPy数组
y_pred_np = y_pred.numpy()
y_true_np = y_true.numpy()
# 使用sklearn计算混淆矩阵
cm = confusion_matrix(y_true_np, y_pred_np)
print(cm)
```
上述代码首先使用torch.randint()函数生成随机的预测结果和真实标签,然后将它们转换为NumPy数组。最后,使用sklearn.metrics.confusion_matrix()函数计算混淆矩阵并打印结果。
请注意,这里使用了sklearn库中的混淆矩阵函数,而不是直接使用PyTorch实现。这是因为sklearn提供了更简单、更灵活的混淆矩阵计算方式。
阅读全文