用c++写一个西瓜数据集的决策树
时间: 2024-05-14 18:19:11 浏览: 12
好的,下面是一个简单的C++代码实现西瓜数据集的决策树:
```c++
#include <math.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <string>
#include <vector>
using namespace std;
// 定义数据结构
struct Sample {
int index;
bool color; // 0: green, 1: yellow
double root; // 0: curl, 1: little-curl, 2: straight
double knocks;
bool texture; // 0: clear, 1: a little blurriness
bool navel; // 0: yes, 1: no
double touch; // 0: hard, 1: a little soft, 2: soft
bool good; // 0: bad, 1: good
};
// 定义数据集
vector<Sample> dataset = {
{1, 0, 0, 0.697, 0, 0, 0.460, 0},
{2, 0, 0, 0.774, 0, 0, 0.376, 0},
{3, 0, 0, 0.634, 0, 0, 0.264, 0},
{4, 0, 0, 0.608, 0, 0, 0.318, 0},
{5, 0, 0, 0.556, 0, 0, 0.215, 0},
{6, 0, 1, 0.403, 0, 0, 0.237, 0},
{7, 0, 1, 0.481, 0, 0, 0.149, 0},
{8, 0, 1, 0.437, 0, 0, 0.211, 0},
{9, 0, 1, 0.666, 0, 0, 0.091, 1},
{10, 0, 1, 0.243, 0, 0, 0.267, 1},
{11, 1, 0, 0.245, 0, 1, 0.057, 1},
{12, 1, 0, 0.343, 0, 1, 0.099, 1},
{13, 1, 0, 0.639, 0, 1, 0.161, 1},
{14, 1, 0, 0.657, 0, 1, 0.198, 1},
{15, 1, 0, 0.360, 0, 1, 0.370, 1},
{16, 1, 1, 0.593, 0, 1, 0.042, 1},
{17, 1, 1, 0.719, 0, 1, 0.103, 1},
{18, 1, 1, 0.359, 0, 1, 0.188, 1},
{19, 1, 1, 0.339, 0, 1, 0.241, 1},
{20, 1, 1, 0.282, 0, 1, 0.257, 1},
{21, 1, 1, 0.748, 1, 0, 0.232, 0},
{22, 1, 1, 0.714, 1, 0, 0.346, 0},
{23, 1, 1, 0.483, 1, 0, 0.312, 0},
{24, 1, 1, 0.478, 1, 0, 0.437, 0},
{25, 1, 1, 0.525, 1, 0, 0.369, 0},
{26, 1, 1, 0.751, 1, 1, 0.315, 0},
{27, 1, 0, 0.532, 1, 1, 0.253, 0},
{28, 1, 0, 0.473, 1, 1, 0.214, 0},
{29, 1, 0, 0.725, 1, 1, 0.267, 0},
{30, 1, 0, 0.446, 1, 1, 0.057, 0}
};
// 定义决策树结点
struct TreeNode {
int feature; // 当前结点选择的特征
double threshold; // 当前结点选择的阈值
bool leaf; // 是否为叶子结点
bool result; // 叶子结点的预测结果
TreeNode *left; // 左子树
TreeNode *right; // 右子树
};
// 计算熵
double entropy(int count_good, int count_bad) {
double p_good = (double)count_good / (double)(count_good + count_bad);
double p_bad = (double)count_bad / (double)(count_good + count_bad);
if (p_good == 0 || p_bad == 0) {
return 0;
}
return -p_good * log2(p_good) - p_bad * log2(p_bad);
}
// 计算数据集在某个特征上的条件熵
double conditional_entropy(vector<Sample> samples, int feature_index, double threshold, int &count_left_good, int &count_left_bad, int &count_right_good, int &count_right_bad) {
count_left_good = count_left_bad = count_right_good = count_right_bad = 0;
for (Sample sample : samples) {
if (sample.good) {
if (sample.index <= threshold) {
count_left_good++;
} else {
count_right_good++;
}
} else {
if (sample.index <= threshold) {
count_left_bad++;
} else {
count_right_bad++;
}
}
}
double p_left = (double)(count_left_good + count_left_bad) / (double)samples.size();
double p_right = (double)(count_right_good + count_right_bad) / (double)samples.size();
return p_left * entropy(count_left_good, count_left_bad) + p_right * entropy(count_right_good, count_right_bad);
}
// 选择最优特征和阈值
void choose_best_feature(vector<Sample> samples, int &best_feature_index, double &best_threshold) {
double min_entropy = 1e9;
int best_count_left_good, best_count_left_bad, best_count_right_good, best_count_right_bad;
for (int i = 0; i < 6; i++) {
double min_value = 1e9, max_value = -1e9;
for (Sample sample : samples) {
if (sample.feature(i) < min_value) {
min_value = sample.feature(i);
}
if (sample.feature(i) > max_value) {
max_value = sample.feature(i);
}
}
for (double threshold = min_value; threshold <= max_value; threshold += 0.1) {
int count_left_good, count_left_bad, count_right_good, count_right_bad;
double current_entropy = conditional_entropy(samples, i, threshold, count_left_good, count_left_bad, count_right_good, count_right_bad);
if (current_entropy < min_entropy) {
min_entropy = current_entropy;
best_feature_index = i;
best_threshold = threshold;
best_count_left_good = count_left_good;
best_count_left_bad = count_left_bad;
best_count_right_good = count_right_good;
best_count_right_bad = count_right_bad;
}
}
}
}
// 构建决策树
TreeNode* build_decision_tree(vector<Sample> samples) {
// 如果所有样本都属于同一类别,则直接返回
bool all_good = true, all_bad = true;
for (Sample sample : samples) {
if (sample.good) {
all_bad = false;
} else {
all_good = false;
}
}
if (all_good) {
TreeNode *node = new TreeNode();
node->leaf = true;
node->result = true;
return node;
}
if (all_bad) {
TreeNode *node = new TreeNode();
node->leaf = true;
node->result = false;
return node;
}
// 选择最优特征和阈值
int best_feature_index;
double best_threshold;
choose_best_feature(samples, best_feature_index, best_threshold);
// 构建子树
TreeNode *node = new TreeNode();
node->feature = best_feature_index;
node->threshold = best_threshold;
node->left = build_decision_tree(vector<Sample>(samples.begin(), samples.begin() + best_count_left_good + best_count_left_bad));
node->right = build_decision_tree(vector<Sample>(samples.begin() + best_count_left_good + best_count_left_bad, samples.end()));
return node;
}
// 预测样本的类别
bool predict(TreeNode *node, Sample sample) {
if (node->leaf) {
return node->result;
}
if (sample.feature(node->feature) <= node->threshold) {
return predict(node->left, sample);
} else {
return predict(node->right, sample);
}
}
int main() {
// 构建决策树
TreeNode *root = build_decision_tree(dataset);
// 预测样本的类别
Sample sample = {31, 1, 0, 0.697, 1, 1, 0.460, 0};
bool result = predict(root, sample);
cout << (result ? "Good" : "Bad") << endl;
// 释放决策树
delete root;
return 0;
}
```
这是一个简单的实现,可能存在一些性能问题和不足之处,但可以作为一个基础的参考。