应用sklearn中的朴贝叶斯分类函数对垃圾短信数据集进行分类和测试,计算在测试集上的准确率、召回率和精准率,绘制混淆矩阵。
时间: 2024-09-24 07:23:39 浏览: 95
在Python中,我们可以使用scikit-learn库(sklearn)中的朴素贝叶斯分类器(如GaussianNB或MultinomialNB)对垃圾短信数据集进行分类。首先,你需要准备数据集,将其划分为训练集和测试集。这里是一个简单的步骤概述:
1. **数据预处理**:加载数据集并对其进行预处理,包括文本清洗(移除停用词、标点等)、分词和向量化。
```python
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
# 假设你有一个名为sms_data的数据框,包含'review'列的文本数据和'label'列的目标标签
X = sms_data['review']
y = sms_data['label']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
vectorizer = CountVectorizer()
X_train_vec = vectorizer.fit_transform(X_train)
X_test_vec = vectorizer.transform(X_test)
```
2. **模型训练**:使用训练集数据训练朴素贝叶斯分类器。
```python
from sklearn.naive_bayes import MultinomialNB
nb_classifier = MultinomialNB()
nb_classifier.fit(X_train_vec, y_train)
```
3. **预测及评估**:在测试集上进行预测,并计算准确率、召回率和精确率。
```python
from sklearn.metrics import accuracy_score, recall_score, precision_score, confusion_matrix
y_pred = nb_classifier.predict(X_test_vec)
accuracy = accuracy_score(y_test, y_pred)
recall = recall_score(y_test, y_pred, pos_label='spam') # 假设 spam 是垃圾邮件类别
precision = precision_score(y_test, y_pred, pos_label='spam')
confusion_matrix = confusion_matrix(y_test, y_pred)
```
4. **混淆矩阵可视化**:使用matplotlib或seaborn库绘制混淆矩阵。
```python
import matplotlib.pyplot as plt
import seaborn as sns
sns.heatmap(confusion_matrix, annot=True, fmt="d")
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix for Spam Detection')
plt.show()
```
阅读全文