分类问题(二)混淆矩阵,分类问题(二)混淆矩阵,Precision与与Recall
混淆矩阵混淆矩阵
衡量一个分类器性能的更好的办法是混淆矩阵。它基于的思想是:计算类别A被分类为类别B的次数。例如在查看分类器将图
片5分类成图片3时,我们会看混淆矩阵的第5行以及第3列。
为了计算一个混淆矩阵,我们首先需要有一组预测值,之后再可以将它们与标注值(label)进行对比。我们也可以在测试集
上做预测,但是最好是先不要动测试集(测试集仅需要在最后的阶段使用,在我们有了一个准备上线的分类器后,最后再用测
试集测试性能)。接下来,我们可以使用cross_val_predict() 方法:
from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
y_train_pred.shape
>(60000,)
与cross_val_score() 方法一样,cross_val_predict() 会执行K-折交叉验证,但是不会返回评估分数,而是返回在每个测试折上
的预测值,加起来就是整个训练数据集的预测值。现在我们可以使用confusion_matricx() 方法获取混淆矩阵。直接传入label
数据(y_train_5)以及预测数据(y_train_pred)即可:
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5, y_train_pred)
>array([[53892, 687],
[ 1891, 3530]])
在这个混淆矩阵中,每一行代表一个真实类别,每一列代表一个预测类别。第一行代表的是“非5”(亦称为negative class)图
片:53892张图片被分类为“非5“类别(它们亦称为true negatives)。剩下的687 张图片被错误的分类为”非5“(亦称为false
positives)。第二行代表的是”真5“(亦称为 positive class):1891张图片被错误地分类为”非5“类别(亦称为false
negatives),剩下的3530 张图片被正确地分类为”真5“(亦称为true positives)。一个完美的分类器应该仅包含true positives
以及true negatives,所以它的混淆矩阵应该仅有主对角线上有非0数值,其他值应都为0。例如,假设我们有了以下一个完美
的预测:
y_train_perfect_predictions = y_train_5
confusion_matrix(y_train_5, y_train_perfect_predictions)
>array([[54579, 0],
[ 0, 5421]])
混淆矩阵可以给我们提供很多信息,但是有时候我们可能需要一个更精准的指标。一个比较好的方式是:查看positive
predictions的精准度。它也称为分类器的精度(precision),它的公式为:
Precision
Precision=TP / (TP + FP)
这里TP 是true positives 的数量,FP 是false positive 的数量。
对于精度,我们仍有办法去构造一个完整精度。比如假设测试集里全部是数字5,然后模型的逻辑是仅输出True。这样就可以
构造一个 100% 精度的模型。所以精度(precision)一般与另一个指标一起用,这另一个指标称为回调(recall),也称为
sensitivity或true positive rate(TPR):它是分类器正确分类positive 条目的比率,公式为:
Recall
TP / (TP+FN)
这里FN是false negatives的数目。
如果对混淆矩阵的这些概念比较模糊的话,可以看看下图:
评论0