用numpy库写一个基于朴素贝叶斯的垃圾邮件分类
时间: 2023-05-13 22:07:22 浏览: 107
可以使用以下代码实现基于朴素贝叶斯的垃圾邮件分类:
```
import numpy as np
class NaiveBayes:
def __init__(self):
self.vocab = set()
self.labels = set()
self.label_word_count = {}
self.label_doc_count = {}
self.total_doc_count = 0
def train(self, X, y):
self.labels = set(y)
self.vocab = set([word for doc in X for word in doc])
self.label_word_count = {label: np.zeros(len(self.vocab)) for label in self.labels}
self.label_doc_count = {label: 0 for label in self.labels}
self.total_doc_count = len(X)
for i in range(len(X)):
label = y[i]
self.label_doc_count[label] += 1
for word in X[i]:
self.label_word_count[label][list(self.vocab).index(word)] += 1
def predict(self, X):
y_pred = []
for doc in X:
scores = {label: np.log(self.label_doc_count[label] / self.total_doc_count) for label in self.labels}
for word in doc:
if word in self.vocab:
for label in self.labels:
word_count = self.label_word_count[label][list(self.vocab).index(word)]
total_count = np.sum(self.label_word_count[label])
scores[label] += np.log((word_count + 1) / (total_count + len(self.vocab)))
y_pred.append(max(scores, key=scores.get))
return y_pred
```
其中,X是一个列表,每个元素是一个文档,表示为一个单词列表;y是一个列表,每个元素是对应文档的类别标签。可以使用以下代码进行训练和预测:
```
X_train = [['hello', 'world', 'spam'], ['spam', 'eggs', 'spam'], ['world', 'world', 'hello']]
y_train = ['spam', 'spam', 'ham']
X_test = [['hello', 'world'], ['spam', 'spam', 'spam']]
nb = NaiveBayes()
nb.train(X_train, y_train)
y_pred = nb.predict(X_test)
print(y_pred) # ['ham', 'spam']
```
这里使用了 Laplace 平滑,避免出现概率为 0 的情况。
阅读全文