混淆矩阵pytorch代码
时间: 2023-12-25 19:30:34 浏览: 133
混淆矩阵是用于评估分类模型性能的一种常用工具。在PyTorch中,可以使用以下代码计算混淆矩阵:
```python
import torch
import numpy as np
def confusion_matrix(preds, labels, num_classes):
conf_matrix = torch.zeros(num_classes, num_classes)
preds = torch.argmax(preds, 1)
for p, t in zip(preds, labels):
conf_matrix[p, t] += 1
return conf_matrix
# 示例用法
preds = torch.tensor([0, 1, 2, 1, 0]) # 模型预测结果
labels = torch.tensor([0, 1, 2, 2, 1]) # 真实标签
num_classes = 3 # 类别数
conf_matrix = confusion_matrix(preds, labels, num_classes)
print(conf_matrix)
```
这段代码定义了一个`confusion_matrix`函数,它接受模型的预测结果`preds`、真实标签`labels`和类别数`num_classes`作为输入,并返回一个大小为`num_classes x num_classes`的混淆矩阵。
在示例中,我们使用了一个简单的示例来演示代码的用法。`preds`是模型的预测结果,`labels`是真实标签,`num_classes`是类别数。运行代码后,将输出计算得到的混淆矩阵。
阅读全文
相关推荐
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![txt](https://img-home.csdnimg.cn/images/20241231045021.png)
![py](https://img-home.csdnimg.cn/images/20250102104920.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)