pytorch计算混淆矩阵
时间: 2023-09-05 22:13:22 浏览: 36
在PyTorch中,可以通过混淆矩阵(confusion matrix)来评估模型的分类效果。下面是一个计算混淆矩阵的示例代码:
```python
import torch
from sklearn.metrics import confusion_matrix
# 创建模型和数据
model = MyModel()
data_loader = DataLoader(dataset)
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 运行模型并预测
y_true = []
y_pred = []
model.eval()
with torch.no_grad():
for data in data_loader:
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
y_true += labels.tolist()
y_pred += predicted.tolist()
# 计算混淆矩阵
confusion_mat = confusion_matrix(y_true, y_pred)
print(confusion_mat)
```
在上面的代码中,我们首先创建了一个模型和一个数据加载器。然后,我们将模型移动到设备(如果有GPU则使用GPU)。接下来,我们运行模型并预测每个样本的类别,并将真实标签和预测标签存储在`y_true`和`y_pred`列表中。最后,我们使用`confusion_matrix`函数计算混淆矩阵,并输出结果。
需要注意的是,`confusion_matrix`函数来自于`sklearn.metrics`模块,因此需要先安装`scikit-learn`库。
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)