用C++写一个朴素贝叶斯算法
时间: 2023-08-04 20:28:17 浏览: 83
好的,以下是一个简单的朴素贝叶斯算法的C++实现,用于文本分类:
```c++
#include <cctype>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <map>
#include <set>
#include <string>
#include <vector>
using namespace std;
// Stop words: common English words that tokenize() leaves out of its output.
std::set<std::string> stopWords{
    "a", "an", "and", "in", "is", "it", "of", "the", "to",
};
// Tokenizer: splits `text` into lowercase alphabetic words.
//
// Every maximal run of alphabetic characters becomes one token, converted to
// lowercase. Any other character (digits, punctuation, whitespace, ...) acts
// as a separator. Tokens found in the global `stopWords` set are dropped.
//
// Returns the remaining tokens in their order of appearance; the result is
// empty when `text` contains no non-stop-word tokens.
std::vector<std::string> tokenize(const std::string& text) {
    std::vector<std::string> tokens;
    std::string token;

    // Emits the pending word (unless it is a stop word) and resets the buffer.
    const auto flush = [&tokens, &token]() {
        if (token.empty()) {
            return;
        }
        if (stopWords.count(token) == 0) {
            tokens.push_back(token);
        }
        token.clear();
    };

    for (const char c : text) {
        // The <cctype> functions have undefined behaviour for negative
        // arguments other than EOF, and plain `char` may be signed, so bytes
        // >= 0x80 (e.g. UTF-8 text) must be converted to unsigned char first.
        const unsigned char uc = static_cast<unsigned char>(c);
        if (std::isalpha(uc)) {
            token += static_cast<char>(std::tolower(uc));
        } else {
            flush();
        }
    }
    flush();  // the text may end in the middle of a word
    return tokens;
}
// Builds a frequency table mapping each distinct token in `tokens` to the
// number of times it occurs. An empty input yields an empty map.
std::map<std::string, int> wordCounts(const std::vector<std::string>& tokens) {
    std::map<std::string, int> frequency;
    for (std::vector<std::string>::const_iterator it = tokens.begin();
         it != tokens.end(); ++it) {
        // operator[] value-initializes a missing entry to 0 before the increment.
        ++frequency[*it];
    }
    return frequency;
}
// Training: accumulates per-class statistics from a labelled text file.
//
// Each non-blank line of `filename` is one document. The first
// whitespace-delimited field is the class label and is used verbatim; the rest
// of the line is the document text and is split with tokenize().
//
// The label is deliberately NOT passed through tokenize(): doing so would
// lowercase it, cut it at the first non-letter (turning both "class1" and
// "class2" into "class") and discard it entirely if it were a stop word.
//
// Parameters:
//   filename          - path of the training file.
//   wordCountsByClass - in/out: label -> (word -> occurrences). An entry is
//                       created for every label seen, even when its documents
//                       contain no usable words.
//   docCountsByClass  - in/out: label -> number of documents with that label.
//
// Existing contents of both maps are kept and added to. If the file cannot be
// opened, an error is printed to stderr and the process exits with status 1.
void train(const std::string& filename,
           std::map<std::string, std::map<std::string, int>>& wordCountsByClass,
           std::map<std::string, int>& docCountsByClass) {
    std::ifstream infile(filename);
    if (!infile) {
        std::cerr << "Error: could not open file " << filename << std::endl;
        std::exit(1);
    }

    const char* const kWhitespace = " \t\r\n\v\f";
    std::string line;
    while (std::getline(infile, line)) {
        const std::string::size_type labelBegin = line.find_first_not_of(kWhitespace);
        if (labelBegin == std::string::npos) {
            continue;  // blank line
        }
        std::string::size_type labelEnd = line.find_first_of(kWhitespace, labelBegin);
        if (labelEnd == std::string::npos) {
            labelEnd = line.size();  // label only, no document text
        }

        const std::string cls = line.substr(labelBegin, labelEnd - labelBegin);
        docCountsByClass[cls] += 1;

        // Always touch the per-class map so every known label has an entry.
        std::map<std::string, int>& counts = wordCountsByClass[cls];
        const std::vector<std::string> tokens = tokenize(line.substr(labelEnd));
        const std::map<std::string, int> lineCounts = wordCounts(tokens);
        // `const auto&` binds to the map's pair<const string, int> directly,
        // avoiding a temporary copy of every element.
        for (const auto& entry : lineCounts) {
            counts[entry.first] += entry.second;
        }
    }
}
// Prediction: returns the most probable class label for `text` under a
// multinomial naive Bayes model with add-one (Laplace) smoothing.
//
// For each class c the score is
//
//   log(N_c / N) + sum over tokens w of log((count(w, c) + 1) / (T_c + |V|))
//
// where N_c is the number of training documents of class c, N the total number
// of training documents, count(w, c) the occurrences of w in class c, T_c the
// total number of word occurrences in class c, and V the vocabulary (all
// distinct words seen in any class). Tokens that are not in V carry no class
// information and are skipped.
//
// Parameters:
//   text              - document to classify; split with tokenize().
//   wordCountsByClass - label -> (word -> occurrences), as filled by train().
//   docCountsByClass  - label -> number of training documents.
//
// Returns the label with the highest score; ties go to the label that comes
// first in map order. Returns "unknown" when `text` yields no tokens or when
// there is no class with a positive document count.
//
// Note: the vocabulary and per-class totals are recomputed on every call
// because the signature carries no cache for them.
std::string predict(const std::string& text,
                    const std::map<std::string, std::map<std::string, int>>& wordCountsByClass,
                    const std::map<std::string, int>& docCountsByClass) {
    const std::vector<std::string> tokens = tokenize(text);
    if (tokens.empty() || docCountsByClass.empty()) {
        return "unknown";
    }

    // Vocabulary shared by all classes: the smoothing denominator must use the
    // same |V| for every class, otherwise the scores are not comparable.
    std::set<std::string> vocabulary;
    for (const auto& classEntry : wordCountsByClass) {
        for (const auto& wordEntry : classEntry.second) {
            vocabulary.insert(wordEntry.first);
        }
    }
    const double vocabularySize = static_cast<double>(vocabulary.size());

    double totalDocs = 0.0;
    for (const auto& classEntry : docCountsByClass) {
        totalDocs += classEntry.second;
    }

    const std::map<std::string, int> noWords;  // stand-in for classes without word counts
    bool haveBest = false;
    double bestLogProb = 0.0;
    std::string bestClass = "unknown";

    for (const auto& classEntry : docCountsByClass) {
        const std::string& cls = classEntry.first;
        const int docsInClass = classEntry.second;
        if (docsInClass <= 0) {
            continue;  // log(0) prior: this class can never be selected
        }

        // find() instead of at(): a label missing from wordCountsByClass is
        // treated as a class with no words rather than throwing.
        const auto countsIt = wordCountsByClass.find(cls);
        const std::map<std::string, int>& counts =
            countsIt == wordCountsByClass.end() ? noWords : countsIt->second;

        double tokensInClass = 0.0;
        for (const auto& wordEntry : counts) {
            tokensInClass += wordEntry.second;
        }
        // Positive whenever a token passes the vocabulary check below,
        // because that implies |V| >= 1.
        const double denominator = tokensInClass + vocabularySize;

        double logProb = std::log(docsInClass / totalDocs);
        for (const std::string& token : tokens) {
            if (vocabulary.count(token) == 0) {
                continue;  // never seen in training
            }
            const auto tokenIt = counts.find(token);
            const double tokenCount = tokenIt == counts.end() ? 0.0 : tokenIt->second;
            logProb += std::log((tokenCount + 1.0) / denominator);
        }

        if (!haveBest || logProb > bestLogProb) {
            haveBest = true;
            bestLogProb = logProb;
            bestClass = cls;
        }
    }
    return bestClass;
}
int main() {
// 训练集
map<string, map<string, int>> wordCountsByClass;
map<string, int> docCountsByClass;
train("train.txt", wordCountsByClass, docCountsByClass);
// 测试集
ifstream testfile("test.txt");
if (!testfile) {
cerr << "Error: could not open file test.txt" << endl;
exit(1);
}
string line;
while (getline(testfile, line)) {
string cls = predict(line, wordCountsByClass, docCountsByClass);
cout << line << " => " << cls << endl;
}
return 0;
}
```
上述代码的训练数据和测试数据分别存放在 train.txt 和 test.txt 文件中。其中,训练数据每行一条样本,行首是类别标签,其后用空格隔开的是文本内容,格式如下:
```
class1 sentence1
class1 sentence2
class2 sentence3
class2 sentence4
...
```
测试数据同样是每行一条待分类的文本,格式与训练数据相同,只是行首没有类别标签。代码中的分词函数和停用词表可以根据实际情况进行调整。
阅读全文