实现决策树的c++代码
时间: 2023-10-14 10:06:35 浏览: 98
决策树C++代码示例
以下是一个简单的决策树实现的C++代码示例:
```
#include<iostream>
#include<vector>
using namespace std;
// 决策树节点
class DecisionTreeNode {
public:
DecisionTreeNode() : feature(-1), label(-1) {} // 初始化特征和标签为-1
int feature; // 节点的特征
int label; // 节点的标签
vector<int> children; // 子节点
};
// 决策树
class DecisionTree {
public:
DecisionTree() : root(0) {} // 初始化根节点为0
void Train(const vector<vector<int>>& data, const vector<int>& labels); // 训练决策树
int Predict(const vector<int>& data); // 预测数据的标签
private:
int BuildTree(const vector<vector<int>>& data, const vector<int>& labels); // 构建决策树
int root; // 根节点
vector<DecisionTreeNode> tree; // 决策树节点集合
};
// 训练决策树
void DecisionTree::Train(const vector<vector<int>>& data, const vector<int>& labels) {
root = BuildTree(data, labels);
}
// 预测数据的标签
int DecisionTree::Predict(const vector<int>& data) {
int node = root;
while (tree[node].label == -1) { // 当前节点不是叶子节点
int feature = tree[node].feature;
if (data[feature] == 0) { // 根据数据特征判断下一个节点
node = tree[node].children[0];
} else {
node = tree[node].children[1];
}
}
return tree[node].label; // 返回叶子节点的标签
}
// 构建决策树
int DecisionTree::BuildTree(const vector<vector<int>>& data, const vector<int>& labels) {
int n = data.size(), m = data[0].size();
if (n == 0) { // 数据为空
DecisionTreeNode node;
node.label = -1;
return tree.size();
}
int cnt0 = 0, cnt1 = 0;
for (int i = 0; i < n; i++) { // 统计标签数量
if (labels[i] == 0) {
cnt0++;
} else {
cnt1++;
}
}
if (cnt0 == n) { // 所有数据标签都为0
DecisionTreeNode node;
node.label = 0;
return tree.size();
}
if (cnt1 == n) { // 所有数据标签都为1
DecisionTreeNode node;
node.label = 1;
return tree.size();
}
int bestFeature = -1, bestFeatureSplit = -1;
double bestGiniIndex = 1e9; // 初始化最小的Gini指数为一个较大的值
for (int i = 0; i < m; i++) { // 计算每个特征的Gini指数
int cnt00 = 0, cnt01 = 0, cnt10 = 0, cnt11 = 0;
for (int j = 0; j < n; j++) {
if (data[j][i] == 0 && labels[j] == 0) {
cnt00++;
} else if (data[j][i] == 0 && labels[j] == 1) {
cnt01++;
} else if (data[j][i] == 1 && labels[j] == 0) {
cnt10++;
} else {
cnt11++;
}
}
double giniIndex = 1.0 - pow(double(cnt00 + cnt01) / n, 2) - pow(double(cnt10 + cnt11) / n, 2);
if (cnt00 + cnt01 > 0) {
giniIndex -= pow(double(cnt00) / (cnt00 + cnt01), 2) + pow(double(cnt01) / (cnt00 + cnt01), 2);
}
if (cnt10 + cnt11 > 0) {
giniIndex -= pow(double(cnt10) / (cnt10 + cnt11), 2) + pow(double(cnt11) / (cnt10 + cnt11), 2);
}
if (giniIndex < bestGiniIndex) { // 更新最小的Gini指数和最优特征
bestGiniIndex = giniIndex;
bestFeature = i;
if (cnt00 + cnt10 > cnt01 + cnt11) {
bestFeatureSplit = 0;
} else {
bestFeatureSplit = 1;
}
}
}
DecisionTreeNode node;
node.feature = bestFeature;
node.label = -1;
tree.push_back(node);
vector<vector<int>> dataSplit[2];
vector<int> labelsSplit[2];
for (int i = 0; i < n; i++) { // 按最优特征划分数据集
if (data[i][bestFeature] == bestFeatureSplit) {
dataSplit[0].push_back(data[i]);
labelsSplit[0].push_back(labels[i]);
} else {
dataSplit[1].push_back(data[i]);
labelsSplit[1].push_back(labels[i]);
}
}
node.children.push_back(BuildTree(dataSplit[0], labelsSplit[0])); // 递归构建子树
node.children.push_back(BuildTree(dataSplit[1], labelsSplit[1]));
tree.back() = node;
return tree.size() - 1;
}
// 测试
int main() {
vector<vector<int>> data{{0, 0}, {0, 1}, {1, 0}, {1, 1}};
vector<int> labels{0, 1, 1, 0};
DecisionTree dt;
dt.Train(data, labels);
cout << dt.Predict(vector<int>{0, 0}) << endl; // 输出0
cout << dt.Predict(vector<int>{0, 1}) << endl; // 输出1
cout << dt.Predict(vector<int>{1, 0}) << endl; // 输出1
cout << dt.Predict(vector<int>{1, 1}) << endl; // 输出0
return 0;
}
```
该代码实现了一个简单的二分类决策树,通过Gini指数选择最优特征,使用递归构建决策树,最终可以对新的数据进行分类预测。
阅读全文