CART算法java实现
时间: 2023-10-22 08:18:25 浏览: 159
CART(Classification and Regression Trees)是一种决策树算法,用于分类和回归问题。下面是CART算法的Java实现。
首先,我们需要定义一个节点类,用于表示CART决策树中的每个节点:
```
/**
 * A single node of a CART decision tree.
 *
 * <p>Internal nodes carry a split rule (featureIndex, threshold): samples whose
 * feature value is &lt;= threshold descend into the left subtree, all others
 * into the right. Leaf nodes (both children null) carry the predicted value.
 */
public class Node {
    private int featureIndex;   // index of the feature tested at this node
    private double threshold;   // split point: feature <= threshold goes left
    private double value;       // prediction stored when this node is a leaf
    private Node left;          // subtree for samples with feature <= threshold
    private Node right;         // subtree for samples with feature > threshold

    /**
     * Creates a node. For a leaf, pass {@code null} for both children.
     *
     * @param featureIndex feature index used by the split (ignored for leaves)
     * @param threshold    split threshold (ignored for leaves)
     * @param value        predicted value (meaningful only for leaves)
     * @param left         left child, or {@code null}
     * @param right        right child, or {@code null}
     */
    public Node(int featureIndex, double threshold, double value, Node left, Node right) {
        this.featureIndex = featureIndex;
        this.threshold = threshold;
        this.value = value;
        this.left = left;
        this.right = right;
    }

    /** A node is a leaf exactly when it has no children. */
    public boolean isLeaf() {
        return (left == null) && (right == null);
    }

    // --- accessors ---

    public int getFeatureIndex() { return featureIndex; }

    public void setFeatureIndex(int featureIndex) { this.featureIndex = featureIndex; }

    public double getThreshold() { return threshold; }

    public void setThreshold(double threshold) { this.threshold = threshold; }

    public double getValue() { return value; }

    public void setValue(double value) { this.value = value; }

    public Node getLeft() { return left; }

    public void setLeft(Node left) { this.left = left; }

    public Node getRight() { return right; }

    public void setRight(Node right) { this.right = right; }
}
```
然后,我们需要定义一个CART类,用于训练和预测:
```
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * CART (Classification and Regression Trees) learner.
 *
 * <p>Builds a binary decision tree by greedily choosing, at each node, the
 * (feature, threshold) split that maximizes the reduction in Gini impurity of
 * the labels. Leaves store the mean of their labels, so pure classification
 * leaves return the exact class label.
 */
public class CART {
    private Node root;           // root of the trained tree (null until fit())
    private int maxDepth;        // maximum tree depth
    private int minSamplesSplit; // minimum samples a node needs to be split

    /**
     * @param maxDepth        maximum depth of the tree
     * @param minSamplesSplit minimum number of samples a node must hold to be split
     */
    public CART(int maxDepth, int minSamplesSplit) {
        this.maxDepth = maxDepth;
        this.minSamplesSplit = minSamplesSplit;
    }

    /**
     * Trains the tree.
     *
     * @param X feature matrix, one row per sample
     * @param y label/target for each row of X
     */
    public void fit(double[][] X, double[] y) {
        root = buildTree(X, y, 0);
    }

    /**
     * Predicts the value for a single sample by walking from the root to a leaf.
     *
     * @param x feature vector of the sample
     * @return the value stored at the reached leaf
     */
    public double predict(double[] x) {
        Node node = root;
        while (!node.isLeaf()) {
            if (x[node.getFeatureIndex()] <= node.getThreshold()) {
                node = node.getLeft();
            } else {
                node = node.getRight();
            }
        }
        return node.getValue();
    }

    // Recursively builds the subtree for (X, y). Returns a leaf when a stopping
    // criterion is met or when no split reduces the impurity.
    private Node buildTree(double[][] X, double[] y, int depth) {
        int nSamples = X.length;
        // Stop: too few samples, maximum depth reached, or node already pure.
        if (nSamples < minSamplesSplit || depth == maxDepth || impurity(y) == 0.0) {
            return new Node(-1, -1, mean(y), null, null);
        }
        int nFeatures = X[0].length;
        double parentImpurity = impurity(y);
        // BUG FIX: the original compared the impurity *reduction* with '<'
        // against +infinity, which selects the WORST split and makes the
        // "no improvement" fallback unreachable. We maximize the reduction
        // and only accept splits with a strictly positive gain.
        double bestReduction = 0.0;
        int bestFeatureIndex = -1;
        double bestThreshold = 0.0;
        for (int i = 0; i < nFeatures; i++) {
            double[] featureValues = new double[nSamples];
            for (int j = 0; j < nSamples; j++) {
                featureValues[j] = X[j][i];
            }
            Arrays.sort(featureValues);
            for (int j = 0; j < nSamples - 1; j++) {
                // Equal neighbors yield a threshold that separates nothing.
                if (featureValues[j] == featureValues[j + 1]) {
                    continue;
                }
                double threshold = (featureValues[j] + featureValues[j + 1]) / 2;
                // BUG FIX: impurity must be computed on the LABELS falling on
                // each side of the split; the original evaluated it on the
                // feature rows returned by split() (which also did not compile).
                List<Double> leftLabels = new ArrayList<>();
                List<Double> rightLabels = new ArrayList<>();
                for (int k = 0; k < nSamples; k++) {
                    if (X[k][i] <= threshold) {
                        leftLabels.add(y[k]);
                    } else {
                        rightLabels.add(y[k]);
                    }
                }
                double[] leftY = toPrimitive(leftLabels);
                double[] rightY = toPrimitive(rightLabels);
                double weighted =
                        (leftY.length * impurity(leftY) + rightY.length * impurity(rightY)) / nSamples;
                double reduction = parentImpurity - weighted;
                if (reduction > bestReduction) {
                    bestReduction = reduction;
                    bestFeatureIndex = i;
                    bestThreshold = threshold;
                }
            }
        }
        // No split improves purity (e.g. all rows identical) -> leaf.
        if (bestFeatureIndex < 0) {
            return new Node(-1, -1, mean(y), null, null);
        }
        // Partition the data on the chosen feature/threshold and recurse.
        List<double[]> leftX = new ArrayList<>();
        List<double[]> rightX = new ArrayList<>();
        List<Double> leftY = new ArrayList<>();
        List<Double> rightY = new ArrayList<>();
        for (int i = 0; i < nSamples; i++) {
            if (X[i][bestFeatureIndex] <= bestThreshold) {
                leftX.add(X[i]);
                leftY.add(y[i]);
            } else {
                rightX.add(X[i]);
                rightY.add(y[i]);
            }
        }
        // BUG FIX: the original passed List<Double> to listToArray(List<double[]>),
        // a compile error; labels now go through toPrimitive().
        Node left = buildTree(listToArray(leftX), toPrimitive(leftY), depth + 1);
        Node right = buildTree(listToArray(rightX), toPrimitive(rightY), depth + 1);
        return new Node(bestFeatureIndex, bestThreshold, -1, left, right);
    }

    // Gini impurity: 1 - sum(p_c^2) over the classes present in y.
    // Generalized to any number of distinct labels (the original hard-coded
    // two classes: label == 1 versus everything else). Counting is done on a
    // sorted copy so no extra collection types are needed.
    private double impurity(double[] y) {
        int n = y.length;
        if (n == 0) {
            return 0.0; // empty set is trivially pure (original divided 0/0)
        }
        double[] sorted = y.clone();
        Arrays.sort(sorted);
        double gini = 1.0;
        int i = 0;
        while (i < n) {
            int j = i;
            while (j < n && sorted[j] == sorted[i]) {
                j++;
            }
            double p = (double) (j - i) / n;
            gini -= p * p;
            i = j;
        }
        return gini;
    }

    // Mean of the targets; used as the leaf value. For a pure classification
    // leaf this is the class label itself; otherwise a regression-style average.
    private double mean(double[] y) {
        double sum = 0;
        for (double value : y) {
            sum += value;
        }
        return sum / y.length;
    }

    // Converts a list of sample rows back to a 2-D array.
    // BUG FIX: handles an empty list (the original crashed on list.get(0)).
    private double[][] listToArray(List<double[]> list) {
        double[][] array = new double[list.size()][];
        for (int i = 0; i < list.size(); i++) {
            array[i] = list.get(i);
        }
        return array;
    }

    // Unboxes a List<Double> of labels into a double[].
    private double[] toPrimitive(List<Double> list) {
        double[] array = new double[list.size()];
        for (int i = 0; i < list.size(); i++) {
            array[i] = list.get(i);
        }
        return array;
    }

    // Getters and setters
    public Node getRoot() {
        return root;
    }

    public void setRoot(Node root) {
        this.root = root;
    }

    public int getMaxDepth() {
        return maxDepth;
    }

    public void setMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
    }

    public int getMinSamplesSplit() {
        return minSamplesSplit;
    }

    public void setMinSamplesSplit(int minSamplesSplit) {
        this.minSamplesSplit = minSamplesSplit;
    }
}
```
最后,我们可以使用以下代码进行训练和预测:
```
/**
 * Demo: trains a small CART tree on five labelled samples and classifies
 * two unseen points, one from each class region.
 */
public class Main {
    public static void main(String[] args) {
        // Five 2-feature training samples with class labels 1 and 2.
        double[][] X = {
            {2.0, 4.0}, {3.0, 6.0}, {4.0, 8.0}, {5.0, 10.0}, {6.0, 12.0}
        };
        double[] y = {1, 1, 1, 2, 2};

        // Tree limited to depth 2; a node needs at least 2 samples to split.
        CART cart = new CART(2, 2);
        cart.fit(X, y);

        double[] x1 = {3.5, 7.0};
        double[] x2 = {5.5, 11.0};
        System.out.println(cart.predict(x1)); // expected output: 1.0
        System.out.println(cart.predict(x2)); // expected output: 2.0
    }
}
```
上述代码使用CART算法训练了一个决策树,并使用该决策树预测了两个样本的类别。
阅读全文