已知y_pred以及y_test,绘制出混淆矩阵图
时间: 2024-09-25 22:02:06 浏览: 38
在Python中,你可以使用`scikit-learn`库中的`confusion_matrix`函数来计算混淆矩阵,并结合`matplotlib`或`seaborn`库来可视化它。以下是一个简单的例子:
首先,确保已经安装了所需的库:
```bash
pip install scikit-learn matplotlib seaborn
```
然后,你可以按照以下步骤操作:
```python
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# 假设y_pred和y_test是你的预测结果和实际标签,它们应该是同长度的一维数组
y_pred = np.array([...])
y_test = np.array([...])
# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
# 将混淆矩阵转换为pandas DataFrame方便绘图
df_cm = pd.DataFrame(cm, index=np.unique(y_test), columns=np.unique(y_test))
# 使用seaborn绘制热力图
plt.figure(figsize=(10, 7))
sns.heatmap(df_cm, annot=True, cmap='Blues') # 可以选择其他颜色映射
plt.title('Confusion Matrix')
plt.xlabel('Predicted Classes')
plt.ylabel('True Classes')
plt.show()
```
这将生成一个显示预测类别和真实类别之间匹配情况的混淆矩阵图形。
阅读全文