pytorch 混淆矩阵代码
时间: 2023-09-24 13:07:49 浏览: 125
confusion_混淆矩阵、pytorch、模型_混淆矩阵pytorch_混淆矩阵_
5星 · 资源好评率100%
以下是一个使用PyTorch计算混淆矩阵的示例代码:
```python
import torch
from sklearn.metrics import confusion_matrix
# 假设你有两个类别
num_classes = 2
# 初始化混淆矩阵
confusion_matrix = torch.zeros(num_classes, num_classes)
# 假设你有一批预测标签和真实标签
preds = torch.tensor([0, 1, 1, 0, 1])
targets = torch.tensor([0, 1, 0, 0, 1])
# 计算混淆矩阵
for t, p in zip(targets.view(-1), preds.view(-1)):
confusion_matrix[t.long(), p.long()] += 1
print(confusion_matrix)
```
这段代码首先创建了一个全零的混淆矩阵,然后使用循环遍历预测标签和真实标签,将每个对应位置的计数加1。最后,打印出混淆矩阵。
需要注意的是,在这个示例中,我们使用了PyTorch和sklearn库。PyTorch用于创建张量和计算,而sklearn库用于计算混淆矩阵。你可以使用pip或conda来安装这些库。
阅读全文