在已经保存模型的情况下,该如何用代码生成混淆矩阵和ROC曲线,并进行可视化
时间: 2024-02-25 20:58:09 浏览: 138
要生成混淆矩阵和ROC曲线,并进行可视化,你可以使用scikit-learn库中的confusion_matrix和roc_curve函数。下面是一个使用已经保存的模型生成混淆矩阵和ROC曲线的示例代码:
```python
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import joblib
# 加载已经保存的模型
model = joblib.load('model.joblib')
# 加载测试数据
X_test = np.load('X_test.npy')
y_test = np.load('y_test.npy')
# 预测概率
y_pred_proba = model.predict_proba(X_test)
# 预测类别
y_pred = np.argmax(y_pred_proba, axis=1)
# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
# 可视化混淆矩阵
sns.heatmap(cm, annot=True, cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()
# 计算ROC曲线和AUC分数
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba[:, 1])
roc_auc = roc_auc_score(y_test, y_pred_proba[:, 1])
# 可视化ROC曲线
plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC)')
plt.legend(loc="lower right")
plt.show()
```
在上面的示例代码中,我们首先加载了已经保存的模型。接着,我们加载了测试数据,并使用predict_proba方法获取模型的预测概率,使用argmax方法获取预测类别。然后,我们使用confusion_matrix函数计算混淆矩阵,并使用heatmap函数将其可视化。最后,我们使用roc_curve函数计算ROC曲线和AUC分数,并使用plot函数将其可视化。
阅读全文