续写代码,利用已有变量对模型绘制roc曲线
时间: 2024-09-18 16:09:15 浏览: 35
为了绘制ROC曲线,我们需要从预测结果中获取每个样本属于正类的概率,并使用这些概率与真实的标签来计算假正率(False Positive Rate, FPR)和真正率(True Positive Rate, TPR),最后将FPR作为x轴,TPR作为y轴绘制曲线。由于原始代码中`SpamDetector`类只提供了类别预测,我们假设需要添加方法来输出属于正例(垃圾邮件)的概率。
```python
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
# 增加一个方法用于返回预测正例的概率
def predict_proba(self, X):
results = []
for x in X:
counts = self.get_word_counts(self.tokenize(x))
spam_score = self.log_class_priors['spam']
for word, _ in counts.items():
if word in self.vocab:
if word in self.word_counts['spam']:
spam_score += math.log((self.word_counts['spam'][word] + 1) /
(sum(self.word_counts['spam'].values()) + len(self.vocab)))
else:
spam_score += math.log(1 / (sum(self.word_counts['spam'].values()) + len(self.vocab)))
# 归一化处理得到属于正类的概率
spam_prob = math.exp(spam_score - max(spam_score, self.log_class_priors['ham']))
results.append(spam_prob)
return results
# 添加predict_proba方法到SpamDetector类
SpamDetector.predict_proba = predict_proba
# 使用新添加的方法获取预测概率
probas_ = MNB.predict_proba(X[:100])
# 根据真实标签和预测概率计算fpr和tpr
fpr, tpr, thresholds = roc_curve(true, probas_)
roc_auc = auc(fpr, tpr)
# 绘制ROC曲线
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
```
以上代码首先通过修改`SpamDetector`类以提供预测正例概率的功能,然后使用这个功能计算出测试集中每条数据被标记为垃圾邮件的概率,并基于此概率与实际标签计算出了FPR和TPR值。最后,使用matplotlib库绘图函数画出了ROC曲线并显示了AUC值。请注意,对于生产环境下的应用,在绘制ROC曲线时应考虑更多的细节和优化方案。
阅读全文