resnet18进行多分类任务,已知预测后的y_pred,y_true和y_pro(都为list形式),如何绘制多分类的prc曲线、roc曲线和混淆矩阵?给出python代码
时间: 2024-09-13 20:16:35 浏览: 34
要绘制多分类的P-R曲线(Precision-Recall Curve)、ROC曲线(Receiver Operating Characteristic Curve)和混淆矩阵,你需要使用一些Python库,如`matplotlib`用于绘图,`sklearn`中的`precision_recall_curve`、`roc_curve`、`confusion_matrix`等函数来计算必要的指标。这里提供一个基本的代码示例来指导你完成这个任务。
首先,确保你已经有了`y_pred`(预测概率)、`y_true`(真实标签)和`y_pro`(预测类别)。`y_pred`应该是一个二维数组,每个类别的预测概率对应一个列;`y_true`和`y_pro`是一维数组,每个元素代表一个样本的真实类别或预测类别。
以下是一个示例代码:
```python
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, roc_curve, auc, confusion_matrix, precision_score, recall_score
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
import numpy as np
# 假设你有三个类别
n_classes = 3
# 将真实值和预测值二值化
y_true_bin = label_binarize(y_true, classes=np.arange(n_classes))
y_pred_bin = label_binarize(y_pro, classes=np.arange(n_classes))
# 计算每个类别的P-R曲线的精确度和召回率
precision = dict()
recall = dict()
for i in range(n_classes):
precision[i], recall[i], _ = precision_recall_curve(y_true_bin[:, i], y_pred[:, i])
plt.plot(recall[i], precision[i], lw=2, label='class {0} (area = {1:0.2f})'.format(i, auc(recall[i], precision[i])))
# 绘制P-R曲线
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Multi-class Precision-Recall curve')
plt.legend(loc="best")
plt.show()
# 计算ROC曲线和AUC
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# 绘制ROC曲线
for i in range(n_classes):
plt.plot(fpr[i], tpr[i], lw=2, label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()
# 计算混淆矩阵
conf_matrix = confusion_matrix(y_true, y_pro)
# 绘制混淆矩阵
plt.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion matrix')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
fmt = 'd'
thresh = conf_matrix.max() / 2.
for i, j in itertools.product(range(conf_matrix.shape[0]), range(conf_matrix.shape[1])):
plt.text(j, i, format(conf_matrix[i, j], fmt),
horizontalalignment="center",
color="white" if conf_matrix[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
```
注意:这段代码假定你已经有了`y_pred`和`y_true`以及`y_pro`,并且它们分别是预测概率、真实标签和预测类别的列表。如果你的数据是其他形式的,你可能需要进行适当的转换。同时,混淆矩阵的代码中`classes`变量需要根据实际类别进行定义,此代码段并没有定义`classes`变量,只是给出了一个大致的绘图方式。
阅读全文