现有好评和差评两个文本文件,我需要对这两个中文文本文件作为训练集,使用sklearn实现贝叶斯分类,区分好评和差评,并给出准确率和结果可视化图,给出示例代码
时间: 2024-02-03 13:11:44 浏览: 137
以下是一个简单的示例代码,用于实现基于sklearn的贝叶斯分类器,区分好评和差评文本文件,并给出准确率和结果可视化图。
首先,需要安装sklearn和matplotlib库。可以使用以下命令进行安装:
```python
!pip install sklearn
!pip install matplotlib
```
接下来,我们可以使用以下代码读取文本文件,并将其转换为sklearn期望的格式:
```python
import os
import jieba
from sklearn.feature_extraction.text import CountVectorizer
# 读取文本文件
def read_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
return content.strip()
# 对文本进行分词
def cut_words(text):
return ' '.join(jieba.cut(text))
# 读取好评和差评文本文件
positive_path = './positive.txt'
negative_path = './negative.txt'
positive_content = read_file(positive_path)
negative_content = read_file(negative_path)
# 对文本进行分词
positive_words = cut_words(positive_content)
negative_words = cut_words(negative_content)
# 将分词后的文本转换为sklearn期望的格式
vectorizer = CountVectorizer()
train_data = vectorizer.fit_transform([positive_words, negative_words])
train_labels = ['positive', 'negative']
```
接下来,我们可以使用以下代码训练贝叶斯分类器,并使用测试数据计算准确率:
```python
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score
# 训练贝叶斯分类器
clf = MultinomialNB()
clf.fit(train_data, train_labels)
# 使用测试数据计算准确率
test_data = vectorizer.transform([' '.join(jieba.cut('这个产品非常好,值得购买。'))])
test_label = 'positive'
predicted_label = clf.predict(test_data)
accuracy = accuracy_score([test_label], predicted_label)
print('Accuracy: {:.2f}'.format(accuracy))
```
最后,我们可以使用以下代码生成结果可视化图:
```python
import numpy as np
import matplotlib.pyplot as plt
# 生成网格点
xx, yy = np.meshgrid(np.linspace(-5, 15, 500), np.linspace(-5, 15, 500))
# 计算每个网格点的预测标签
Z = clf.predict(vectorizer.transform([' '.join(jieba.cut(''.join(text)))) for text in np.c_[xx.ravel(), yy.ravel()]])
Z = Z.reshape(xx.shape)
# 绘制决策边界
plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
# 绘制训练数据点
plt.scatter(train_data.toarray()[:1, 0], train_data.toarray()[:1, 1], color='red')
plt.scatter(train_data.toarray()[1:, 0], train_data.toarray()[1:, 1], color='blue')
# 设置图形参数
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.show()
```
完整代码如下:
阅读全文