用c++写一个决策树代码
时间: 2023-06-15 14:04:48 浏览: 233
decision-makingTree:这是c++写的决策树
以下是一个简单的决策树的C++代码示例:
```c++
#include <cmath>
#include <iostream>
#include <vector>
using namespace std;
// 节点结构体
// Decision-tree node. Internal nodes test one binary feature and have two
// children (children[0] = feature value 0, children[1] = feature value 1);
// leaf nodes carry the predicted class label.
struct Node {
    bool is_leaf = false;           // true if this node is a leaf
    int feature = -1;               // feature index tested here (internal nodes only)
    std::vector<Node*> children;    // owned children: [0] -> value 0, [1] -> value 1
    int label = -1;                 // class label 0/1 (leaf nodes only)

    // Recursively release the owned subtree so `delete root;` frees
    // every node instead of leaking the interior of the tree.
    ~Node() {
        for (Node* child : children) {
            delete child;
        }
    }
};
// 训练决策树
// Recursively builds a binary decision tree (ID3-style) for boolean
// features and class labels in {0, 1}.
//
// data:     samples; each sample is a vector of 0/1 feature values
// labels:   class label (0 or 1) per sample, parallel to `data`
// features: indices of the features still available for splitting
//
// Returns a heap-allocated tree; ownership passes to the caller.
//
// Fixes vs. the original version:
//  - the recursion now receives the PARTITIONED data/labels (the original
//    passed the full data set down, so splitting had no effect),
//  - the split criterion now maximizes information gain
//    (parent entropy minus weighted child entropy; the original maximized
//    the weighted child entropy, i.e. chose the WORST feature),
//  - best_feature_index is initialized before the search,
//  - an empty partition yields a majority-vote leaf instead of recursing
//    on zero samples.
Node* train(vector<vector<int>>& data, vector<int>& labels, vector<int>& features) {
    Node* node = new Node;
    int num_samples = (int)data.size();
    int num_features = (int)features.size();

    // Count how many samples belong to each class.
    vector<int> label_counts(2, 0);
    for (int i = 0; i < num_samples; i++) {
        label_counts[labels[i]]++;
    }

    // Pure node (or, defensively, no samples): emit a leaf.
    if (num_samples == 0 || label_counts[0] == num_samples || label_counts[1] == num_samples) {
        node->is_leaf = true;
        node->label = (label_counts[1] > label_counts[0]) ? 1 : 0;
        return node;
    }

    // No features left to split on: emit a majority-vote leaf.
    if (num_features == 0) {
        node->is_leaf = true;
        node->label = (label_counts[0] >= label_counts[1]) ? 0 : 1;
        return node;
    }

    // Shannon entropy of a two-class count pair (0 when one class is empty).
    auto entropy = [](int c0, int c1) {
        int total = c0 + c1;
        double h = 0.0;
        if (c0 > 0) { double p = (double)c0 / total; h -= p * log2(p); }
        if (c1 > 0) { double p = (double)c1 / total; h -= p * log2(p); }
        return h;
    };
    double parent_entropy = entropy(label_counts[0], label_counts[1]);

    // Choose the feature with maximum information gain.
    int best_feature_index = features[0];
    double best_information_gain = -1.0;
    for (int i = 0; i < num_features; i++) {
        int feature_index = features[i];
        vector<int> left_counts(2, 0), right_counts(2, 0);
        for (int j = 0; j < num_samples; j++) {
            if (data[j][feature_index] == 0) {
                left_counts[labels[j]]++;
            } else {
                right_counts[labels[j]]++;
            }
        }
        int left_n = left_counts[0] + left_counts[1];
        int right_n = right_counts[0] + right_counts[1];
        double weighted_child_entropy = 0.0;
        if (left_n > 0) {
            weighted_child_entropy += (double)left_n / num_samples * entropy(left_counts[0], left_counts[1]);
        }
        if (right_n > 0) {
            weighted_child_entropy += (double)right_n / num_samples * entropy(right_counts[0], right_counts[1]);
        }
        double information_gain = parent_entropy - weighted_child_entropy;
        if (information_gain > best_information_gain) {
            best_information_gain = information_gain;
            best_feature_index = feature_index;
        }
    }

    node->is_leaf = false;
    node->feature = best_feature_index;

    // Partition the samples (and their labels) by the chosen feature.
    vector<vector<int>> left_data, right_data;
    vector<int> left_labels, right_labels;
    for (int i = 0; i < num_samples; i++) {
        if (data[i][best_feature_index] == 0) {
            left_data.push_back(data[i]);
            left_labels.push_back(labels[i]);
        } else {
            right_data.push_back(data[i]);
            right_labels.push_back(labels[i]);
        }
    }

    // Remove the consumed feature before recursing.
    vector<int> remaining_features;
    for (int i = 0; i < num_features; i++) {
        if (features[i] != best_feature_index) {
            remaining_features.push_back(features[i]);
        }
    }

    // Recurse on each partition; an empty partition becomes a leaf with
    // this node's majority label.
    auto build_child = [&](vector<vector<int>>& part_data, vector<int>& part_labels) -> Node* {
        if (part_data.empty()) {
            Node* leaf = new Node;
            leaf->is_leaf = true;
            leaf->label = (label_counts[0] >= label_counts[1]) ? 0 : 1;
            return leaf;
        }
        return train(part_data, part_labels, remaining_features);
    };
    node->children.push_back(build_child(left_data, left_labels));
    node->children.push_back(build_child(right_data, right_labels));
    return node;
}
// 预测
// Classifies `sample` by walking the tree from `node` down to a leaf,
// at each internal node following the branch matching the tested
// feature's value, and returns the leaf's class label.
int predict(Node* node, vector<int>& sample) {
    Node* current = node;
    while (!current->is_leaf) {
        // children[0] covers feature value 0, children[1] covers value 1.
        int branch = (sample[current->feature] == 0) ? 0 : 1;
        current = current->children[branch];
    }
    return current->label;
}
int main() {
    // Training set: eight samples, three binary (0/1) features each.
    vector<vector<int>> data = {
        {1, 0, 0}, {1, 0, 1},
        {0, 1, 0}, {0, 1, 1},
        {1, 1, 0}, {1, 1, 1},
        {0, 0, 0}, {0, 0, 1},
    };
    vector<int> labels = {0, 1, 1, 0, 1, 1, 0, 1};
    vector<int> features = {0, 1, 2};

    // Build the tree, then classify one query sample.
    Node* root = train(data, labels, features);
    vector<int> sample = {1, 0, 1};
    int prediction = predict(root, sample);
    cout << "Prediction: " << prediction << endl;
    return 0;
}
```
以上代码实现了一个简单的二分类决策树,训练数据是8个样本,每个样本有3个二值(0/1)特征。特征选择采用信息增益法(基于熵),分类目标是样本所属的类别(0或1)。预测时,给定一个样本,从根节点开始遍历,根据当前节点所测特征的取值选择左/右子树,直到到达叶子节点,返回其类别标签。
阅读全文