pytorch画出混淆矩阵
时间: 2023-08-06 15:07:25 浏览: 119
要使用PyTorch绘制混淆矩阵,你可以按照以下步骤进行操作:
首先,确保你已经导入了需要的库和模块:
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
```
然后,定义一个函数来绘制混淆矩阵:
```python
def plot_confusion_matrix(confusion_matrix, labels):
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()
```
接下来,你需要计算混淆矩阵。假设你有模型的预测结果和真实标签,你可以使用`sklearn.metrics`库中的`confusion_matrix`函数来计算混淆矩阵。这个函数接受两个参数:真实标签和预测标签。
```python
from sklearn.metrics import confusion_matrix
# 假设predictions是模型的预测结果,labels是真实标签
cm = confusion_matrix(labels, predictions)
```
最后,调用`plot_confusion_matrix`函数来绘制混淆矩阵:
```python
# 假设labels是类别的列表
plot_confusion_matrix(cm, labels)
```
这样就可以绘制出混淆矩阵了。注意,你需要将类别的列表传递给`plot_confusion_matrix`函数,以确保正确的类别标签显示在矩阵的横纵轴上。
希望对你有所帮助!
阅读全文