用 C++ 写一个判断西瓜好坏的决策树代码,并附上注释和训练数据
时间: 2023-06-15 09:04:42 浏览: 112
好的,下面是C++代码实现:
```c++
#include <cmath>
#include <iostream>
#include <vector>
using namespace std;
// One watermelon sample: three categorical attributes plus its class label.
// Every field holds a small integer code; note that root and sound are
// declared as float even though they carry the same kind of code as color.
struct Watermelon {
    // Color: 0 = dark green, 1 = jet black, 2 = pale white.
    int color;

    // Root (stem): 0 = curled, 1 = slightly curled, 2 = stiff.
    float root;

    // Knock sound: 0 = dull, 1 = muffled, 2 = crisp.
    float sound;

    // Class label: 0 = bad melon, 1 = good melon.
    int good;
};
// Training set: one row per sample, fields in struct order
// {color, root, sound, good}. All values are the integer category codes
// described on the fields of Watermelon. The table holds 14 rows.
// NOTE(review): the 5th and 14th rows have identical attributes {2, 1, 0}
// but opposite labels, so no tree over these three attributes can classify
// both of them correctly.
// NOTE(review): no row uses sound == 2, so that category never occurs in
// training.
Watermelon train_data[] = {
{0, 0, 0, 0},
{0, 0, 1, 0},
{1, 0, 1, 1},
{2, 0, 1, 1},
{2, 1, 0, 1},
{2, 2, 1, 0},
{1, 1, 0, 1},
{0, 2, 1, 0},
{0, 1, 0, 1},
{2, 1, 1, 1},
{0, 1, 1, 1},
{1, 0, 0, 0},
{1, 1, 1, 1},
{2, 1, 0, 0}
};
// Tallies the class labels of the first `len` samples of `data`.
//   data - samples to inspect (not modified)
//   len  - number of samples; a value <= 0 leaves both counters at zero
//   good - out: number of samples whose label equals 1
//   bad  - out: number of samples with any other label
void count_good_bad(Watermelon *data, int len, int &good, int &bad) {
    good = 0;
    bad = 0;
    int remaining = len;
    for (Watermelon *melon = data; remaining > 0; ++melon, --remaining) {
        // Pick the counter this sample belongs to, then bump it.
        int &tally = (melon->good == 1) ? good : bad;
        ++tally;
    }
}
// 计算数据集中给定属性的某个值的好瓜和坏瓜的个数
void count_good_bad_by_attr(Watermelon *data, int len, int attr, float value, int &good, int &bad) {
good = 0;
bad = 0;
for (int i = 0; i < len; i++) {
if (data[i].good == 1 && data[i].color == value) {
good++;
} else if (data[i].good == 0 && data[i].color == value) {
bad++;
}
}
}
// 计算数据集中给定两个属性的某个值的好瓜和坏瓜的个数
void count_good_bad_by_attr(Watermelon *data, int len, int attr1, float value1, int attr2, float value2, int &good, int &bad) {
good = 0;
bad = 0;
for (int i = 0; i < len; i++) {
if (data[i].good == 1 && data[i].color == value1 && data[i].root == value2) {
good++;
} else if (data[i].good == 0 && data[i].color == value1 && data[i].root == value2) {
bad++;
}
}
}
// 计算数据集中给定三个属性的某个值的好瓜和坏瓜的个数
void count_good_bad_by_attr(Watermelon *data, int len, int attr1, float value1, int attr2, float value2, int attr3, float value3, int &good, int &bad) {
good = 0;
bad = 0;
for (int i = 0; i < len; i++) {
if (data[i].good == 1 && data[i].color == value1 && data[i].root == value2 && data[i].sound == value3) {
good++;
} else if (data[i].good == 0 && data[i].color == value1 && data[i].root == value2 && data[i].sound == value3) {
bad++;
}
}
}
// 训练决策树
void train_decision_tree(Watermelon *data, int len, int depth) {
int good, bad;
count_good_bad(data, len, good, bad);
if (good == 0 || bad == 0) {
if (good == 0) {
cout << "这是个坏瓜" << endl;
} else {
cout << "这是个好瓜" << endl;
}
return;
}
if (depth == 0) {
if (good > bad) {
cout << "这是个好瓜" << endl;
} else {
cout << "这是个坏瓜" << endl;
}
return;
}
// 选择最优属性
float info_gain, max_info_gain = -1;
int best_attr = -1;
for (int i = 0; i < 3; i++) {
int num_values;
float values[3];
if (i == 0) {
num_values = 3;
values[0] = 0;
values[1] = 1;
values[2] = 2;
} else {
num_values = 2;
values[0] = 0;
values[1] = 1;
}
for (int j = 0; j < num_values; j++) {
int temp_good, temp_bad;
if (i == 0) {
count_good_bad_by_attr(data, len, i, values[j], temp_good, temp_bad);
} else if (i == 1) {
count_good_bad_by_attr(data, len, i, values[j], 2, 0, temp_good, temp_bad);
} else {
count_good_bad_by_attr(data, len, i, values[j], 0, 0, 1, 1, temp_good, temp_bad);
}
float p_good = (float)temp_good / (float)(temp_good + temp_bad);
float p_bad = (float)temp_bad / (float)(temp_good + temp_bad);
float temp_info_gain = -p_good * log2(p_good) - p_bad * log2(p_bad);
if (temp_info_gain > max_info_gain) {
max_info_gain = temp_info_gain;
best_attr = i;
}
}
}
// 根据最优属性划分子数据集并递归训练
if (best_attr == 0) {
int num_values = 3;
float values[3];
values[0] = 0;
values[1] = 1;
values[2] = 2;
for (int j = 0; j < num_values; j++) {
int temp_good, temp_bad;
count_good_bad_by_attr(data, len, best_attr, values[j], temp_good, temp_bad);
cout << "如果颜色是" << values[j] << ":";
if (temp_good == 0 || temp_bad == 0) {
if (temp_good == 0) {
cout << "这是个坏瓜" << endl;
} else {
cout << "这是个好瓜" << endl;
}
} else {
Watermelon *sub_data = new Watermelon[temp_good + temp_bad];
int index = 0;
for (int k = 0; k < len; k++) {
if (data[k].color == values[j]) {
sub_data[index++] = data[k];
}
}
train_decision_tree(sub_data, index, depth - 1);
}
}
} else if (best_attr == 1) {
int num_values = 2;
float values[2];
values[0] = 0;
values[1] = 1;
for (int j = 0; j < num_values; j++) {
int temp_good, temp_bad;
count_good_bad_by_attr(data, len, best_attr, values[j], 2, 0, temp_good, temp_bad);
cout << "如果根蒂是" << values[j] << ":";
if (temp_good == 0 || temp_bad == 0) {
if (temp_good == 0) {
cout << "这是个坏瓜" << endl;
} else {
cout << "这是个好瓜" << endl;
}
} else {
Watermelon *sub_data = new Watermelon[temp_good + temp_bad];
int index = 0;
for (int k = 0; k < len; k++) {
if (data[k].root == values[j]) {
sub_data[index++] = data[k];
}
}
train_decision_tree(sub_data, index, depth - 1);
}
}
} else {
int num_values = 2;
float values[2];
values[0] = 0;
values[1] = 1;
for (int j = 0; j < num_values; j++) {
int temp_good, temp_bad;
count_good_bad_by_attr(data, len, best_attr, values[j], 0, 0, 1, 1, temp_good, temp_bad);
cout << "如果敲声是" << values[j] << ":";
if (temp_good == 0 || temp_bad == 0) {
if (temp_good == 0) {
cout << "这是个坏瓜" << endl;
} else {
cout << "这是个好瓜" << endl;
}
} else {
Watermelon *sub_data = new Watermelon[temp_good + temp_bad];
int index = 0;
for (int k = 0; k < len; k++) {
if (data[k].sound == values[j]) {
sub_data[index++] = data[k];
}
}
train_decision_tree(sub_data, index, depth - 1);
}
}
}
}
// Entry point: trains on the built-in data set and prints the resulting tree.
int main() {
    // Derive the sample count from the array itself so it cannot drift out of
    // sync with the table when rows are added or removed.
    const int sample_count = static_cast<int>(sizeof(train_data) / sizeof(train_data[0]));
    // Depth limit 3: one split per attribute (color, root, sound).
    train_decision_tree(train_data, sample_count, 3);
    return 0;
}
```
训练数据如下:
```
颜色 根蒂 敲声 是否好瓜
0 0 0 0
0 0 1 0
1 0 1 1
2 0 1 1
2 1 0 1
2 2 1 0
1 1 0 1
0 2 1 0
0 1 0 1
2 1 1 1
0 1 1 1
1 0 0 0
1 1 1 1
2 1 0 0
```
其中,颜色属性的值为 0 代表青绿,1 代表乌黑,2 代表浅白;根蒂属性的值为 0 代表蜷缩,1 代表稍蜷,2 代表硬挺;敲声属性的值为 0 代表沉闷,1 代表浊响,2 代表清脆;是否好瓜属性的值为 0 代表否,1 代表是。
阅读全文