Multidimensional classification model: a decision tree implemented with C++ classes, with a worked example
Date: 2023-10-23 21:13:02
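A decision tree is a widely used method for classification and regression, and it handles classification problems with multiple features (dimensions). Below is an example implementation using C++ classes.
First, we define a node class that represents one node of the tree: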
```
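#include <algorithm>
#include <cmath>
#include <iostream>
#include <unordered_map>
#include <vector>
using namespace std;

// One node of the tree. Internal nodes use feature/threshold; leaves use label.
class Node {
public:
    int feature;       // index of the feature tested at this node
    double threshold;  // split threshold: value <= threshold goes left
    int label;         // class label (meaningful for leaves; -1 if unset)
    Node* left;        // left subtree
    Node* right;       // right subtree
    Node() : feature(0), threshold(0.0), label(-1), left(nullptr), right(nullptr) {}
};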
```
Next, we define a decision tree class that builds and owns the tree:
```
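class DecisionTree {
public:
    DecisionTree();
    ~DecisionTree();
    // The class owns raw Node pointers, so copying is disabled to avoid double deletion.
    DecisionTree(const DecisionTree&) = delete;
    DecisionTree& operator=(const DecisionTree&) = delete;
    // Trains the tree on one row of feature values per sample and one label per row.
    void buildTree(const vector<vector<double>>& data, const vector<int>& labels);
    // Returns the predicted label for a single sample.
    int predict(const vector<double>& data) const;
private:
    Node* root;
    void destroy(Node* node);
    int getMajorityLabel(const vector<int>& labels) const;
    int getBestFeature(const vector<vector<double>>& data, const vector<int>& labels, vector<double>& thresholds) const;
    Node* buildSubTree(const vector<vector<double>>& data, const vector<int>& labels, const vector<double>& thresholds);
};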
```
Here, buildTree() trains the tree and predict() classifies a single sample.
buildTree() and its recursive helper buildSubTree() are implemented as follows:
```
void DecisionTree::buildTree(const vector<vector<double>>& data, const vector<int>& labels) {
    destroy(root);  // release any previously built tree
    root = nullptr;
    if (data.empty() || data.size() != labels.size()) {
        return;  // nothing valid to train on; leave the tree empty
    }
    vector<double> thresholds(data[0].size(), 0.0);
    root = buildSubTree(data, labels, thresholds);
}
Node* DecisionTree::buildSubTree(const vector<vector<double>>& data, const vector<int>& labels, const vector<double>& thresholds) {
    Node* node = new Node;
    if (labels.empty()) {
        node->label = -1;  // no samples: leaf with the "unknown" label
        return node;
    }
    int majorityLabel = getMajorityLabel(labels);
    node->label = majorityLabel;  // used as the prediction if this node stays a leaf
    // getBestFeature() writes the chosen threshold into its argument, and the
    // parameter here is const, so work on a local copy.
    vector<double> localThresholds = thresholds;
    int bestFeature = getBestFeature(data, labels, localThresholds);
    if (bestFeature == -1) {
        return node;  // no usable split: majority-label leaf
    }
    node->feature = bestFeature;
    node->threshold = localThresholds[bestFeature];
    vector<vector<double>> leftData;
    vector<int> leftLabels;
    vector<vector<double>> rightData;
    vector<int> rightLabels;
    for (size_t i = 0; i < data.size(); i++) {
        if (data[i][bestFeature] <= node->threshold) {
            leftData.push_back(data[i]);
            leftLabels.push_back(labels[i]);
        }
        else {
            rightData.push_back(data[i]);
            rightLabels.push_back(labels[i]);
        }
    }
    if (leftData.empty() || rightData.empty()) {
        return node;  // the split separates nothing: majority-label leaf
    }
    node->left = buildSubTree(leftData, leftLabels, localThresholds);
    node->right = buildSubTree(rightData, rightLabels, localThresholds);
    return node;
}
```
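buildSubTree() first checks whether the label list is empty and, if so, returns a leaf whose label is -1. Otherwise it computes the most frequent label, which becomes the node's prediction if the node ends up as a leaf. It then asks getBestFeature() for the feature and threshold to split on. If no feature is found, or if the chosen split leaves one side empty, the function returns a leaf carrying the majority label. Otherwise the samples are partitioned into a left subset (feature value <= threshold) and a right subset, and both subtrees are built recursively.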
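The partition step at the heart of buildSubTree() can be looked at in isolation. The following is a minimal sketch and not part of the class above: `Split` and `splitByThreshold` are hypothetical names introduced only for illustration. It returns row indices instead of copying rows, which is one possible way to avoid the per-level copies made by leftData and rightData.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not in the original class): row indices on each side of a split.
struct Split {
    std::vector<std::size_t> left;   // rows with value <= threshold
    std::vector<std::size_t> right;  // rows with value >  threshold
};

// Partitions the rows of `data` by comparing one feature against a threshold.
Split splitByThreshold(const std::vector<std::vector<double>>& data,
                       std::size_t feature, double threshold) {
    Split s;
    for (std::size_t i = 0; i < data.size(); ++i) {
        assert(feature < data[i].size());  // every row must contain the feature
        if (data[i][feature] <= threshold) {
            s.left.push_back(i);
        } else {
            s.right.push_back(i);
        }
    }
    return s;
}
```

If either side comes back empty, the split separates nothing, which is the situation in which buildSubTree() falls back to a majority-label leaf.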
predict() is implemented as follows:
```
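int DecisionTree::predict(const vector<double>& data) const {
    if (root == nullptr) {
        return -1;  // the tree has not been built yet
    }
    const Node* node = root;
    // Internal nodes always have two children; a node without them is a leaf.
    while (node->left != nullptr && node->right != nullptr) {
        if (data[node->feature] <= node->threshold) {
            node = node->left;
        }
        else {
            node = node->right;
        }
    }
    return node->label;
}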
```
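predict() starts at the root and, at each internal node, compares the sample's value for that node's feature with the node's threshold to choose the left or the right subtree, until it reaches a leaf. The label of that leaf is the prediction.
Finally, we need a few helper functions: one that finds the most frequent label, one that selects the best feature and threshold, and the functions that manage the tree's memory. They can be implemented as follows: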
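To make the traversal concrete, here is a small self-contained sketch that walks a hand-built one-split tree. `MiniNode`, `classify` and `makeStump` are hypothetical names used only in this example, and the sketch assumes C++14 for `std::make_unique`; it uses smart pointers instead of the raw pointers in the Node class.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for the article's Node, with owning smart pointers.
struct MiniNode {
    int feature = 0;
    double threshold = 0.0;
    int label = -1;
    std::unique_ptr<MiniNode> left;
    std::unique_ptr<MiniNode> right;
};

// Walks from the root to a leaf: value <= threshold goes left, otherwise right.
int classify(const MiniNode& root, const std::vector<double>& x) {
    const MiniNode* node = &root;
    while (node->left && node->right) {
        assert(static_cast<std::size_t>(node->feature) < x.size());
        node = (x[node->feature] <= node->threshold) ? node->left.get()
                                                     : node->right.get();
    }
    return node->label;
}

// Builds a one-split tree: feature 0 <= 2.5 gives label 0, otherwise label 1.
std::unique_ptr<MiniNode> makeStump() {
    auto root = std::make_unique<MiniNode>();
    root->feature = 0;
    root->threshold = 2.5;
    root->left = std::make_unique<MiniNode>();
    root->left->label = 0;
    root->right = std::make_unique<MiniNode>();
    root->right->label = 1;
    return root;
}
```

With this stump, `classify(*makeStump(), {3.5, 2.5})` goes right at the root because 3.5 > 2.5 and returns 1. A sample whose value equals the threshold exactly goes left.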
```
// Returns the most frequent label, or -1 if the list is empty.
// Ties are broken by the (unspecified) iteration order of unordered_map.
int DecisionTree::getMajorityLabel(const vector<int>& labels) const {
    if (labels.empty()) {
        return -1;
    }
    unordered_map<int, int> labelCounts;
    for (int label : labels) {
        labelCounts[label]++;  // operator[] value-initializes missing counts to 0
    }
    int majorityLabel = -1;
    int maxCount = -1;
    for (const auto& entry : labelCounts) {
        if (entry.second > maxCount) {
            maxCount = entry.second;
            majorityLabel = entry.first;
        }
    }
    return majorityLabel;
}
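// Returns the feature whose best threshold yields the largest information gain
// and stores that threshold in thresholds[feature]; returns -1 if no split helps.
int DecisionTree::getBestFeature(const vector<vector<double>>& data, const vector<int>& labels, vector<double>& thresholds) const {
    if (data.empty() || data.size() != labels.size()) {
        return -1;
    }
    // Shannon entropy (in bits) of a list of labels; 0 for an empty list.
    auto entropyOf = [](const vector<int>& ls) {
        unordered_map<int, int> counts;
        for (int l : ls) {
            counts[l]++;
        }
        double h = 0.0;
        for (const auto& kv : counts) {
            double p = static_cast<double>(kv.second) / ls.size();
            h -= p * log2(p);
        }
        return h;
    };
    const size_t numFeatures = data[0].size();
    const size_t numSamples = labels.size();
    const double parentEntropy = entropyOf(labels);
    double maxGain = 1e-12;  // only accept splits with a strictly positive gain
    int bestFeature = -1;
    vector<double> featureValues(numSamples);
    for (size_t i = 0; i < numFeatures; i++) {
        for (size_t j = 0; j < numSamples; j++) {
            featureValues[j] = data[j][i];
        }
        sort(featureValues.begin(), featureValues.end());
        // Candidate thresholds are midpoints between consecutive distinct values.
        for (size_t j = 0; j + 1 < numSamples; j++) {
            if (featureValues[j] == featureValues[j + 1]) {
                continue;
            }
            double threshold = (featureValues[j] + featureValues[j + 1]) / 2.0;
            vector<int> leftLabels;
            vector<int> rightLabels;
            for (size_t k = 0; k < numSamples; k++) {
                if (data[k][i] <= threshold) {
                    leftLabels.push_back(labels[k]);
                }
                else {
                    rightLabels.push_back(labels[k]);
                }
            }
            // Information gain = parent entropy - weighted entropy of the children.
            double leftWeight = static_cast<double>(leftLabels.size()) / numSamples;
            double rightWeight = static_cast<double>(rightLabels.size()) / numSamples;
            double gain = parentEntropy - leftWeight * entropyOf(leftLabels) - rightWeight * entropyOf(rightLabels);
            if (gain > maxGain) {
                maxGain = gain;
                bestFeature = static_cast<int>(i);
                thresholds[i] = threshold;
            }
        }
    }
    return bestFeature;
}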
// Recursively frees a subtree (post-order).
void DecisionTree::destroy(Node* node) {
    if (node == nullptr) {
        return;
    }
    destroy(node->left);
    destroy(node->right);
    delete node;
}
DecisionTree::DecisionTree() {
    root = nullptr;
}
DecisionTree::~DecisionTree() {
    destroy(root);
}
```
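For reference, the information gain criterion commonly used to rank candidate splits can be written as two standalone functions. This is a minimal sketch under stated assumptions, not a member of the class above: `labelEntropy` and `informationGain` are hypothetical names, and entropy is measured in bits.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <vector>

// Shannon entropy in bits of a list of class labels; 0 for an empty list.
double labelEntropy(const std::vector<int>& labels) {
    std::map<int, std::size_t> counts;
    for (int l : labels) {
        ++counts[l];
    }
    double h = 0.0;
    for (const auto& kv : counts) {
        double p = static_cast<double>(kv.second) / labels.size();
        h -= p * std::log2(p);
    }
    return h;
}

// Entropy of the parent minus the size-weighted entropy of the two children.
double informationGain(const std::vector<int>& parent,
                       const std::vector<int>& left,
                       const std::vector<int>& right) {
    assert(left.size() + right.size() == parent.size());
    if (parent.empty()) {
        return 0.0;
    }
    double n = static_cast<double>(parent.size());
    return labelEntropy(parent)
         - (left.size() / n) * labelEntropy(left)
         - (right.size() / n) * labelEntropy(right);
}
```

For the labels {0, 0, 1, 1}, splitting into {0, 0} and {1, 1} gives a gain of 1 bit, the maximum possible for two balanced classes, while splitting into {0} and {0, 1, 1} gives only about 0.311 bits. A split that leaves each child with the same class mix as the parent gains nothing.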
Now we can use the classes defined above to build and query a decision tree. Here is a simple example:
```
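int main() {
    // Training data: four samples with two features each
    vector<vector<double>> data = {{1.0, 2.0}, {2.0, 1.0}, {3.0, 4.0}, {4.0, 3.0}};
    // Training labels, one per sample
    vector<int> labels = {0, 0, 1, 1};
    // Create and train the decision tree
    DecisionTree dt;
    dt.buildTree(data, labels);
    // Predict the label of a test sample
    vector<double> testData = {3.5, 2.5};
    int pred = dt.predict(testData);
    cout << "Prediction: " << pred << endl;
    // The tree is released by the destructor when dt goes out of scope
    return 0;
}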
```
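The output should be "Prediction: 1", meaning the test sample is assigned label 1. The root splits on feature 0 at threshold 2.5, and since 3.5 > 2.5 the sample follows the right branch, whose training samples all have label 1.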
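A single prediction says little about how well a tree fits the data, so one natural follow-up is to measure accuracy over a labeled set. The sketch below is an illustration only: `accuracy` is a hypothetical helper, written as a template over any callable so that it stays self-contained and does not depend on the class above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Fraction of samples for which predict(sample) equals the expected label.
// Returns 0.0 for an empty data set.
template <typename Predictor>
double accuracy(Predictor predict,
                const std::vector<std::vector<double>>& data,
                const std::vector<int>& labels) {
    assert(data.size() == labels.size());  // one label per sample
    if (data.empty()) {
        return 0.0;
    }
    std::size_t correct = 0;
    for (std::size_t i = 0; i < data.size(); ++i) {
        if (predict(data[i]) == labels[i]) {
            ++correct;
        }
    }
    return static_cast<double>(correct) / data.size();
}
```

With the tree from main(), it could be called as `accuracy([&dt](const std::vector<double>& x) { return dt.predict(x); }, data, labels)`. Accuracy measured on the training data itself tends to be optimistic, so a held-out test set gives a more honest estimate.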