朴素贝叶斯垃圾邮件文档识别代码
时间: 2023-09-28 16:11:59 浏览: 58
以下是朴素贝叶斯垃圾邮件文档识别的 Python 代码示例:
```python
import os
import io
import string
from collections import Counter
from math import log
# 加载数据集
def load_dataset(path):
data = []
for root, dirs, files in os.walk(path):
for file in files:
file_path = os.path.join(root, file)
with io.open(file_path, 'r', encoding='utf-8') as f:
data.append((file, f.read()))
return data
# 数据预处理
def preprocess(text):
# 小写化
text = text.lower()
# 去除标点符号
text = text.translate(str.maketrans('', '', string.punctuation))
# 去除数字
text = text.translate(str.maketrans('', '', string.digits))
# 分词
words = text.split()
# 去除停用词
stopwords = set(['a', 'an', 'the'])
words = [word for word in words if word not in stopwords]
# 统计词频
word_count = Counter(words)
return word_count
# 训练模型
def train(dataset):
spam_word_count = Counter()
ham_word_count = Counter()
spam_doc_count = 0
ham_doc_count = 0
for label, text in dataset:
if label.startswith('spam'):
spam_word_count += preprocess(text)
spam_doc_count += 1
else:
ham_word_count += preprocess(text)
ham_doc_count += 1
spam_word_prob = {}
ham_word_prob = {}
for word, count in spam_word_count.items():
spam_word_prob[word] = log((count + 1) / (sum(spam_word_count.values()) + len(spam_word_count)))
for word, count in ham_word_count.items():
ham_word_prob[word] = log((count + 1) / (sum(ham_word_count.values()) + len(ham_word_count)))
spam_doc_prob = log(spam_doc_count / len(dataset))
ham_doc_prob = log(ham_doc_count / len(dataset))
return spam_word_prob, ham_word_prob, spam_doc_prob, ham_doc_prob
# 预测
def predict(text, spam_word_prob, ham_word_prob, spam_doc_prob, ham_doc_prob):
word_count = preprocess(text)
spam_prob = spam_doc_prob
ham_prob = ham_doc_prob
for word, count in word_count.items():
if word in spam_word_prob:
spam_prob += spam_word_prob[word] * count
if word in ham_word_prob:
ham_prob += ham_word_prob[word] * count
if spam_prob > ham_prob:
return 'spam'
else:
return 'ham'
# 测试
def test(dataset, spam_word_prob, ham_word_prob, spam_doc_prob, ham_doc_prob):
tp = 0
tn = 0
fp = 0
fn = 0
for label, text in dataset:
predict_label = predict(text, spam_word_prob, ham_word_prob, spam_doc_prob, ham_doc_prob)
if label.startswith('spam'):
if predict_label == 'spam':
tp += 1
else:
fn += 1
else:
if predict_label == 'spam':
fp += 1
else:
tn += 1
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / (tp + tn + fp + fn)
print('precision:', precision)
print('recall:', recall)
print('f1 score:', f1)
print('accuracy:', accuracy)
if __name__ == '__main__':
dataset = load_dataset('data')
spam_word_prob, ham_word_prob, spam_doc_prob, ham_doc_prob = train(dataset)
test(dataset, spam_word_prob, ham_word_prob, spam_doc_prob, ham_doc_prob)
```
其中,`load_dataset` 函数用于加载数据集,`preprocess` 函数用于数据预处理,包括小写化、去除标点符号、去除数字、分词和去除停用词等操作,`train` 函数用于训练模型,包括统计垃圾邮件和非垃圾邮件的词频、计算词的条件概率和文档的先验概率,`predict` 函数用于预测文档的标签,即垃圾邮件或非垃圾邮件,`test` 函数用于测试模型的准确率、精确率、召回率和 F1 值等性能指标。
阅读全文