朴素贝叶斯垃圾短信分类代码
时间: 2023-09-12 14:06:09 浏览: 94
以下是一个简单的朴素贝叶斯垃圾短信分类代码:
``` python
import pandas as pd
import numpy as np
import re
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score
# 读取数据
data = pd.read_csv('spam.csv', encoding='latin-1')
# 保留有用的列
data = data[['v1', 'v2']]
data = data.rename(columns={'v1': 'label', 'v2': 'text'})
# 将标签转换为二进制
data['label'] = np.where(data['label'] == 'spam', 1, 0)
# 数据清理
def clean_text(text):
text = re.sub('[^a-zA-Z]', ' ', text)
text = text.lower()
text = text.split()
text = ' '.join(text)
return text
data['text'] = data['text'].apply(clean_text)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(data['text'], data['label'], test_size=0.2, random_state=42)
# 特征提取
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(X_train)
X_test = vectorizer.transform(X_test)
# 建立模型
model = MultinomialNB()
model.fit(X_train, y_train)
# 预测并计算准确率
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: {:.2f}%".format(accuracy * 100))
```
其中,`spam.csv` 文件是一个包含垃圾短信和非垃圾短信的数据集。代码首先读取数据,并将标签转换为二进制。然后进行数据清理,去除标点符号、转换为小写字母等。接着使用 `train_test_split` 函数将数据集划分为训练集和测试集。使用 `CountVectorizer` 对文本进行特征提取,并使用 `MultinomialNB` 来建立模型。最后,使用预测结果计算准确率。
阅读全文