Naive Bayes classification algorithm in Python
Date: 2024-01-04 20:03:48
Below is a simple Python implementation of a naive Bayes classifier for text classification.
```python
import math


class NaiveBayesClassifier:
    """Multinomial naive Bayes over whitespace-separated tokens, with add-one smoothing."""

    def __init__(self):
        self.classes = []            # class labels in first-seen order
        self.class_doc_count = {}    # class -> number of training documents
        self.class_word_count = {}   # class -> {word: occurrences}
        self.class_total_count = {}  # class -> total number of word tokens
        self.vocabulary = set()      # every distinct word seen in training
        self.total_docs = 0

    def train(self, train_data):
        """Update the counts from an iterable of (class_name, text) pairs."""
        for class_name, text in train_data:
            if class_name not in self.class_word_count:
                self.classes.append(class_name)
                self.class_doc_count[class_name] = 0
                self.class_word_count[class_name] = {}
                self.class_total_count[class_name] = 0
            self.class_doc_count[class_name] += 1
            self.total_docs += 1
            word_count = self.class_word_count[class_name]
            for word in text.split():
                self.vocabulary.add(word)
                word_count[word] = word_count.get(word, 0) + 1
                self.class_total_count[class_name] += 1

    def classify(self, text):
        """Return the class with the highest log score, or None if untrained."""
        words = text.split()
        max_score = -math.inf
        max_class = None
        for class_name in self.classes:
            # Log prior: fraction of training documents that belong to this class.
            score = math.log(self.class_doc_count[class_name] / self.total_docs)
            denominator = self.class_total_count[class_name] + len(self.vocabulary)
            word_count = self.class_word_count[class_name]
            for word in words:
                # Add-one (Laplace) smoothing: unseen words get a small nonzero probability.
                score += math.log((word_count.get(word, 0) + 1) / denominator)
            if score > max_score:
                max_score = score
                max_class = class_name
        return max_class
```
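The per-word term in `classify` is the add-one-smoothed estimate from the standard multinomial naive Bayes decision rule, which can be written as follows. The symbols are introduced here for illustration and do not appear in the code: $\operatorname{count}(w, c)$ corresponds to `class_word_count[c][w]`, $N_c$ to `class_total_count[c]`, and $|V|$ to `len(self.vocabulary)`.

$$
\hat{c} = \arg\max_{c} \left[ \log P(c) + \sum_{i=1}^{n} \log \frac{\operatorname{count}(w_i, c) + 1}{N_c + |V|} \right]
$$

Here $w_1, \dots, w_n$ are the tokens of the text to classify and $P(c)$ is the class prior, conventionally estimated as the fraction of training documents labeled $c$. Working in log space avoids floating-point underflow from multiplying many small probabilities, and the $+1$ in the numerator means a word never seen with class $c$ contributes $\log \frac{1}{N_c + |V|}$ rather than $\log 0$.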
Usage example:
```python
train_data = [
    ('spam', 'buy cheap viagra now'),
    ('spam', 'make money fast'),
    ('ham', 'hello world'),
    ('ham', 'python is awesome')
]
classifier = NaiveBayesClassifier()
classifier.train(train_data)
print(classifier.classify('buy now, make money'))
# Output: spam
print(classifier.classify('hello python world'))
# Output: ham
```
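In this example, we define a small training set in which the first element of each tuple is the class label and the second is the text. We then train a `NaiveBayesClassifier` instance on it and call its `classify` method to classify new text.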
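To see where these predictions come from, the smoothed per-word probabilities can be recomputed by hand. The following is a minimal sketch that does not depend on the class above: `word_counts`, `vocabulary`, and `smoothed_log_likelihood` are names introduced here for illustration, and the sketch deliberately leaves out the class prior and compares likelihoods only.

```python
import math
from collections import Counter

# Same toy training set as the usage example.
train_data = [
    ('spam', 'buy cheap viagra now'),
    ('spam', 'make money fast'),
    ('ham', 'hello world'),
    ('ham', 'python is awesome'),
]

# Per-class word counts and the shared vocabulary.
word_counts = {}
for label, text in train_data:
    word_counts.setdefault(label, Counter()).update(text.split())
vocabulary = set().union(*word_counts.values())


def smoothed_log_likelihood(label, text):
    """Sum of log P(word | label) using add-one (Laplace) smoothing."""
    counts = word_counts[label]
    denominator = sum(counts.values()) + len(vocabulary)
    return sum(math.log((counts[word] + 1) / denominator) for word in text.split())


for query in ('buy now, make money', 'hello python world'):
    print(query.split())  # show the tokens that are actually scored
    for label in word_counts:
        print(f'  {label}: {smoothed_log_likelihood(label, query):.3f}')
```

Because `str.split()` splits on whitespace only, the first query yields the token `now,` with the comma attached. That token never occurs in the training data, so it is scored as an unseen word under both classes. For `'hello python world'`, each word has smoothed probability 2/17 under `ham` and 1/19 under `spam`, which is why `ham` gets the higher score.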