对fetch_20newsgroups进行分析,要求: 1)显示数据集的标签,及第一个样本的内容。 2)使用伯努利朴素贝叶斯对数据进行分析,求训练集及测试集的准确率,用classification_report求相关指标值,统计总体运行时间。
时间: 2024-10-27 16:13:58 浏览: 29
首先,`fetch_20newsgroups`是scikit-learn库中的一个函数,用于从互联网上下载20 Newsgroups文本分类数据集,它包含了多个新闻组主题的文档,常用于文本分类的示例。
1)加载数据并查看部分信息:
```python
from sklearn.datasets import fetch_20newsgroups
import pandas as pd
data = fetch_20newsgroups(shuffle=True, random_state=42)
labels = data.target_names
print("Data Set Labels:", labels)
# 显示第一个样本的内容和标签
sample_text = data.data[0]
sample_label = data.target[0]
print(f"First Sample Content:\n{sample_text}\nLabel: {labels[sample_label]}")
```
2)接下来我们使用伯努利朴素贝叶斯模型对数据进行分析:
```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import BernoulliNB
from sklearn.model_selection import train_test_split, time
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
vectorizer = CountVectorizer(binary=True) # 使用二进制表示文本特征
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=42)
start_time = time()
model = BernoulliNB()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
train_accuracy = accuracy_score(y_train, model.predict(X_train))
test_accuracy = accuracy_score(y_test, y_pred)
fit_time = time() - start_time
print(f"Training Accuracy: {train_accuracy}")
print(f"Testing Accuracy: {test_accuracy}")
conf_mat = confusion_matrix(y_test, y_pred)
report = classification_report(y_test, y_pred)
print("Confusion Matrix:")
print(conf_mat)
print("\nClassification Report:")
print(report)
```
在这个例子中,我们计算了训练集和测试集的准确率,并通过`classification_report`获取了精确度、召回率、F1分数等评估指标。同时,我们还记录了整个模型训练过程的时间。
阅读全文