写一个c++西瓜决策树代码
时间: 2023-06-15 14:04:41 浏览: 83
以下是一个简单的西瓜决策树的 C++ 代码实现:
```c++
#include<iostream>
#include<vector>
#include<cmath>
using namespace std;
// 定义西瓜类
class Watermelon {
public:
int color; // 颜色,0:青绿,1:乌黑,2:浅白
int root; // 根蒂,0:蜷缩,1:稍蜷,2:硬挺
int sound; // 敲声,0:清脆,1:浊响
int weight; // 重量,0:<0.5kg,1:0.5~0.75kg,2:>0.75kg
int type; // 类型,0:好瓜,1:坏瓜
};
// 计算信息熵
double entropy(const vector<Watermelon>& dataset) {
int n = dataset.size();
int cnt[2] = { 0 };
for (int i = 0; i < n; i++) {
cnt[dataset[i].type]++;
}
double ent = 0;
for (int i = 0; i < 2; i++) {
if (cnt[i] > 0) {
double p = (double)cnt[i] / n;
ent -= p * log2(p);
}
}
return ent;
}
// 计算信息增益
double infoGain(const vector<Watermelon>& dataset, int feature) {
int n = dataset.size();
int cnt[3][2] = { 0 };
for (int i = 0; i < n; i++) {
cnt[dataset[i].color][dataset[i].type]++;
}
double gain = entropy(dataset);
for (int i = 0; i < 3; i++) {
int cntf[2] = { cnt[i][0], cnt[i][1] };
double ent = entropy(vector<Watermelon>(dataset.begin(), dataset.end()));
for (int j = 0; j < 2; j++) {
if (cntf[j] > 0) {
double p = (double)cntf[j] / n;
ent -= p * log2(p);
}
}
gain -= (double)cnt[i][0] / n * ent;
}
return gain;
}
// 定义节点类
class TreeNode {
public:
int feature; // 分裂特征
int threshold; // 分裂阈值
int type; // 类型,0:好瓜,1:坏瓜
TreeNode* left; // 左子树
TreeNode* right; // 右子树
TreeNode(int f, int t) {
feature = f;
threshold = -1;
type = t;
left = nullptr;
right = nullptr;
}
~TreeNode() {
if (left != nullptr) {
delete left;
}
if (right != nullptr) {
delete right;
}
}
};
// 构建决策树
void buildDecisionTree(TreeNode* node, const vector<Watermelon>& dataset, const vector<int>& features) {
if (features.size() == 0) {
return;
}
int n = dataset.size();
int cnt[2] = { 0 };
for (int i = 0; i < n; i++) {
cnt[dataset[i].type]++;
}
if (cnt[0] == n) {
node->type = 0;
return;
}
if (cnt[1] == n) {
node->type = 1;
return;
}
double maxGain = -1;
int bestFeature = -1;
int bestThreshold = -1;
for (int i = 0; i < features.size(); i++) {
int f = features[i];
int cntf[3][2] = { 0 };
for (int j = 0; j < n; j++) {
cntf[dataset[j].color][dataset[j].type]++;
}
for (int j = 0; j < 3; j++) {
int cntfj[2] = { cntf[j][0], cntf[j][1] };
double ent = entropy(vector<Watermelon>(dataset.begin(), dataset.end()));
for (int k = 0; k < 2; k++) {
if (cntfj[k] > 0) {
double p = (double)cntfj[k] / n;
ent -= p * log2(p);
}
}
if (ent > maxGain) {
maxGain = ent;
bestFeature = f;
bestThreshold = j;
}
}
}
if (maxGain <= 0) {
return;
}
node->feature = bestFeature;
node->threshold = bestThreshold;
vector<int> leftIndices, rightIndices;
for (int i = 0; i < n; i++) {
if (dataset[i].color < bestThreshold) {
leftIndices.push_back(i);
}
else {
rightIndices.push_back(i);
}
}
node->left = new TreeNode(-1, -1);
buildDecisionTree(node->left, vector<Watermelon>(dataset.begin(), dataset.end()), features);
node->right = new TreeNode(-1, -1);
buildDecisionTree(node->right, vector<Watermelon>(dataset.begin(), dataset.end()), features);
}
// 预测
int predict(const TreeNode* node, const Watermelon& wm) {
if (node->feature < 0) {
return node->type;
}
if (wm.color < node->threshold) {
return predict(node->left, wm);
}
else {
return predict(node->right, wm);
}
}
int main() {
vector<Watermelon> dataset{
{0, 0, 0, 0, 0},
{0, 0, 0, 1, 0},
{1, 0, 0, 0, 1},
{2, 1, 0, 0, 1},
{2, 2, 1, 0, 1},
{2, 2, 1, 1, 0},
{1, 2, 1, 1, 1},
{0, 1, 0, 0, 0},
{0, 2, 1, 0, 1},
{2, 1, 1, 0, 1},
{0, 1, 1, 1, 1},
{1, 1, 0, 1, 1},
{1, 0, 1, 0, 1},
{2, 1, 0, 1, 0}
};
vector<int> features{ 0, 1, 2, 3 };
TreeNode* root = new TreeNode(-1, -1);
buildDecisionTree(root, dataset, features);
cout << predict(root, { 0, 1, 0, 1, 0 }) << endl; // 输出 0,表示好瓜
cout << predict(root, { 2, 0, 1, 1, 0 }) << endl; // 输出 1,表示坏瓜
delete root;
return 0;
}
```
以上代码实现了西瓜决策树的构建与预测。由于西瓜数据集较小,所以该实现是比较简单的。在实际应用中,需要考虑更多的特征和更大的数据集,以提高决策树的准确性。