Java 实现决策树算法
时间: 2023-11-14 16:12:33 浏览: 400
决策树算法是一种常见的机器学习算法,可以用于分类和回归问题。在 Java 中,实现决策树算法需要使用递归和面向对象的编程思想。
下面是一个简单的决策树分类器的 Java 实现:
```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * ID3 decision-tree classifier for categorical features.
 *
 * <p>Features are addressed by column index; every sample is a list of
 * String feature values and the parallel label list carries the class of
 * each sample. Splits are chosen by maximum information gain; each feature
 * is used at most once per path, which guarantees termination.
 */
public class DecisionTree {
    // Root of the trained tree; null until fit() has been called.
    private Node root;

    public DecisionTree() {}

    /**
     * Trains the tree on the given samples.
     *
     * @param data   samples; each inner list holds the categorical feature
     *               values of one sample, all lists the same length
     * @param labels class label of each sample, parallel to {@code data}
     * @throws IllegalArgumentException if the sets are empty or of unequal size
     */
    public void fit(ArrayList<ArrayList<String>> data, ArrayList<String> labels) {
        if (data.size() != labels.size()) {
            throw new IllegalArgumentException(
                "data size " + data.size() + " != labels size " + labels.size());
        }
        if (data.isEmpty()) {
            throw new IllegalArgumentException("training set is empty");
        }
        // All feature column indices are initially available for splitting.
        Set<Integer> available = new HashSet<>();
        for (int i = 0; i < data.get(0).size(); i++) {
            available.add(i);
        }
        this.root = buildTree(data, labels, available);
    }

    /**
     * Classifies one sample.
     *
     * @param sample feature values in the same column order used during training
     * @return the predicted class label
     * @throws IllegalStateException if {@code fit} has not been called
     */
    public String predict(ArrayList<String> sample) {
        if (this.root == null) {
            throw new IllegalStateException("fit() must be called before predict()");
        }
        return classify(sample, this.root);
    }

    // Walks the tree to a leaf. A feature value never seen during training
    // falls back to the majority label stored on the current node.
    private String classify(ArrayList<String> sample, Node node) {
        if (node.isLeaf()) {
            return node.label;
        }
        Node child = node.children.get(sample.get(node.featureIndex));
        if (child == null) {
            return node.majorityLabel;
        }
        return classify(sample, child);
    }

    // Recursive ID3 construction. `available` holds the feature indices not
    // yet used on this path; removing the chosen feature before recursing
    // bounds the depth by the number of features.
    private Node buildTree(ArrayList<ArrayList<String>> data,
                           ArrayList<String> labels,
                           Set<Integer> available) {
        if (isHomogeneous(labels)) {
            return Node.leaf(labels.get(0));
        }
        if (available.isEmpty()) {
            return Node.leaf(getMajorityLabel(labels));
        }
        int best = getBestFeature(data, labels, available);
        // Internal node keeps the majority label as a fallback for values
        // absent from the training subset.
        Node node = Node.internal(best, getMajorityLabel(labels));
        Set<Integer> remaining = new HashSet<>(available);
        remaining.remove(best);
        for (Map.Entry<String, List<Integer>> e : partitionByFeature(data, best).entrySet()) {
            ArrayList<ArrayList<String>> subData = new ArrayList<>();
            ArrayList<String> subLabels = new ArrayList<>();
            for (int row : e.getValue()) {
                subData.add(data.get(row));
                subLabels.add(labels.get(row));
            }
            node.children.put(e.getKey(), buildTree(subData, subLabels, remaining));
        }
        return node;
    }

    // Groups row indices by their value of the given feature column.
    private Map<String, List<Integer>> partitionByFeature(ArrayList<ArrayList<String>> data,
                                                          int feature) {
        Map<String, List<Integer>> partition = new HashMap<>();
        for (int i = 0; i < data.size(); i++) {
            partition.computeIfAbsent(data.get(i).get(feature), k -> new ArrayList<>()).add(i);
        }
        return partition;
    }

    // Returns the available feature index with the highest information gain.
    private int getBestFeature(ArrayList<ArrayList<String>> data,
                               ArrayList<String> labels,
                               Set<Integer> available) {
        double parentEntropy = getEntropy(labels);
        double maxGain = Double.NEGATIVE_INFINITY;
        int best = -1;
        for (int feature : available) {
            double gain = parentEntropy - getConditionalEntropy(data, labels, feature);
            if (gain > maxGain) {
                maxGain = gain;
                best = feature;
            }
        }
        return best;
    }

    // H(labels | feature): entropy of the labels within each feature-value
    // subset, weighted by the subset's share of the rows.
    private double getConditionalEntropy(ArrayList<ArrayList<String>> data,
                                         ArrayList<String> labels,
                                         int feature) {
        double conditional = 0;
        for (List<Integer> rows : partitionByFeature(data, feature).values()) {
            ArrayList<String> subsetLabels = new ArrayList<>();
            for (int row : rows) {
                subsetLabels.add(labels.get(row));
            }
            conditional += ((double) rows.size() / data.size()) * getEntropy(subsetLabels);
        }
        return conditional;
    }

    // Shannon entropy (base 2) of the label distribution.
    private double getEntropy(ArrayList<String> labels) {
        double entropy = 0;
        for (int count : getCounts(labels).values()) {
            double p = (double) count / labels.size();
            entropy -= p * Math.log(p) / Math.log(2);
        }
        return entropy;
    }

    // Most frequent label; ties broken by map iteration order.
    private String getMajorityLabel(ArrayList<String> labels) {
        String majority = null;
        int maxCount = -1;
        for (Map.Entry<String, Integer> e : getCounts(labels).entrySet()) {
            if (e.getValue() > maxCount) {
                maxCount = e.getValue();
                majority = e.getKey();
            }
        }
        return majority;
    }

    // True when every label equals the first one (list must be non-empty here;
    // fit() rejects empty training sets).
    private boolean isHomogeneous(ArrayList<String> labels) {
        for (String label : labels) {
            if (!label.equals(labels.get(0))) {
                return false;
            }
        }
        return true;
    }

    // Frequency of each distinct label.
    private Map<String, Integer> getCounts(ArrayList<String> labels) {
        Map<String, Integer> counts = new HashMap<>();
        for (String label : labels) {
            counts.merge(label, 1, Integer::sum);
        }
        return counts;
    }

    // Tree node: either a leaf carrying a class label, or an internal node
    // splitting on one feature index. Static factories replace the original
    // pair of String constructors, which had identical signatures.
    private static class Node {
        final int featureIndex;     // split column; -1 for leaves
        final String label;         // leaf label; null for internal nodes
        final String majorityLabel; // fallback for unseen feature values
        final Map<String, Node> children = new HashMap<>();

        private Node(int featureIndex, String label, String majorityLabel) {
            this.featureIndex = featureIndex;
            this.label = label;
            this.majorityLabel = majorityLabel;
        }

        static Node leaf(String label) {
            return new Node(-1, label, label);
        }

        static Node internal(int featureIndex, String majorityLabel) {
            return new Node(featureIndex, null, majorityLabel);
        }

        boolean isLeaf() {
            return this.label != null;
        }
    }
}
```
在这个实现中,`DecisionTree` 类包含了决策树的构建、训练和预测方法。`Node` 类表示决策树节点,包含了节点的特征、标签和子节点等信息。这里使用了递归的方法构建决策树,每次递归都会选择信息增益最大的特征进行划分,直到子集中的样本全部属于同一类别或无法继续划分为止。
下面是一个使用上述决策树分类器的例子:
```java
public static void main(String[] args) {
    // Classic weather data set: outlook, temperature, humidity, wind.
    String[][] rows = {
        {"sunny", "hot", "high", "weak"},
        {"sunny", "hot", "high", "strong"},
        {"overcast", "hot", "high", "weak"},
        {"rainy", "mild", "high", "weak"},
        {"rainy", "cool", "normal", "weak"},
        {"rainy", "cool", "normal", "strong"},
        {"overcast", "cool", "normal", "strong"},
        {"sunny", "mild", "high", "weak"},
        {"sunny", "cool", "normal", "weak"},
        {"rainy", "mild", "normal", "weak"},
        {"sunny", "mild", "normal", "strong"},
        {"overcast", "mild", "high", "strong"},
        {"overcast", "hot", "normal", "weak"},
        {"rainy", "mild", "high", "strong"}
    };
    ArrayList<ArrayList<String>> data = new ArrayList<>();
    for (String[] row : rows) {
        data.add(new ArrayList<>(Arrays.asList(row)));
    }
    ArrayList<String> labels = new ArrayList<>(Arrays.asList(
        "no", "no", "yes", "yes", "yes", "no", "yes",
        "no", "yes", "yes", "yes", "yes", "yes", "no"));
    // Train on the full set, then classify one sample.
    DecisionTree dt = new DecisionTree();
    dt.fit(data, labels);
    ArrayList<String> sample = new ArrayList<>(Arrays.asList("sunny", "hot", "high", "weak"));
    String prediction = dt.predict(sample);
    System.out.println(prediction);
}
```
这个例子中,我们使用了一个简单的天气数据集,包含了天气状况和是否打高尔夫的标签。我们先构建了一个 `DecisionTree` 对象,然后调用 `fit` 方法进行训练,最后使用 `predict` 方法对新样本进行预测。
阅读全文