C4.5 Decision Tree Algorithm in Java
C4.5 is a decision tree algorithm that chooses the splitting attribute by gain ratio, i.e. the information gain of a split divided by its split information, and it can be used for classification and prediction. Below is a simple Java implementation for categorical features:
```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class C45DecisionTree {

    private Node root;

    // Builds the tree from training rows. Each row maps feature names to values;
    // the class label is expected under the key "play".
    public void train(List<Map<String, String>> data, List<String> features) {
        root = buildTree(data, features);
    }

    // Walks the tree from the root, following the branch that matches each
    // feature value of the input, until a leaf is reached.
    public String predict(Map<String, String> input) {
        Node node = root;
        while (!node.isLeaf()) {
            String featureName = node.getFeatureName();
            String featureValue = input.get(featureName);
            Node child = node.getChild(featureValue);
            if (child == null) {
                // No branch was built for this feature value during training
                return "unknown";
            }
            node = child;
        }
        return node.getFeatureValue();
    }

    private Node buildTree(List<Map<String, String>> data, List<String> features) {
        if (data.isEmpty()) {
            return new Node("unknown");
        }
        // All rows share the same class label: return a leaf with that label
        String commonFeatureValue = getCommonFeatureValue(data);
        if (commonFeatureValue != null) {
            return new Node(commonFeatureValue);
        }
        // No features left to split on: return a leaf with the majority label
        if (features.isEmpty()) {
            return new Node(getMostCommonFeatureValue(data));
        }
        String bestFeature = getBestFeature(data, features);
        if (bestFeature == null) {
            // No feature yields a positive gain ratio: fall back to the majority label
            return new Node(getMostCommonFeatureValue(data));
        }
        // Internal node: store the splitting feature and recurse on each value
        Node node = new Node(bestFeature, null);
        List<String> featureValues = getFeatureValues(data, bestFeature);
        for (String featureValue : featureValues) {
            List<Map<String, String>> subset = getSubset(data, bestFeature, featureValue);
            List<String> remainingFeatures = new ArrayList<>(features);
            remainingFeatures.remove(bestFeature);
            Node child = buildTree(subset, remainingFeatures);
            node.addChild(featureValue, child);
        }
        return node;
    }

    // Returns the class label if every row has the same one, otherwise null.
    private String getCommonFeatureValue(List<Map<String, String>> data) {
        String featureValue = null;
        for (Map<String, String> row : data) {
            String rowFeatureValue = row.get("play");
            if (featureValue == null) {
                featureValue = rowFeatureValue;
            } else if (!featureValue.equals(rowFeatureValue)) {
                return null;
            }
        }
        return featureValue;
    }

    // Returns the most frequent class label in the data.
    private String getMostCommonFeatureValue(List<Map<String, String>> data) {
        Map<String, Integer> featureValueCounts = new HashMap<>();
        for (Map<String, String> row : data) {
            String featureValue = row.get("play");
            featureValueCounts.put(featureValue, featureValueCounts.getOrDefault(featureValue, 0) + 1);
        }
        String mostCommonFeatureValue = null;
        int mostCommonFeatureValueCount = 0;
        for (Map.Entry<String, Integer> entry : featureValueCounts.entrySet()) {
            if (entry.getValue() > mostCommonFeatureValueCount) {
                mostCommonFeatureValue = entry.getKey();
                mostCommonFeatureValueCount = entry.getValue();
            }
        }
        return mostCommonFeatureValue;
    }

    // Picks the feature with the highest gain ratio; returns null if none is positive.
    private String getBestFeature(List<Map<String, String>> data, List<String> features) {
        double maxGainRatio = 0;
        String bestFeature = null;
        for (String feature : features) {
            double gainRatio = getGainRatio(data, feature);
            if (gainRatio > maxGainRatio) {
                maxGainRatio = gainRatio;
                bestFeature = feature;
            }
        }
        return bestFeature;
    }

    // Gain ratio = information gain / split information (the C4.5 splitting criterion).
    private double getGainRatio(List<Map<String, String>> data, String feature) {
        double entropy = getEntropy(data);
        double splitInfo = getSplitInfo(data, feature);
        double featureEntropy = getFeatureEntropy(data, feature);
        if (splitInfo == 0) {
            // The feature has only one value here, so splitting on it is useless
            return 0;
        }
        return (entropy - featureEntropy) / splitInfo;
    }

    // Shannon entropy of the class label distribution.
    private double getEntropy(List<Map<String, String>> data) {
        Map<String, Integer> featureValueCounts = new HashMap<>();
        for (Map<String, String> row : data) {
            String featureValue = row.get("play");
            featureValueCounts.put(featureValue, featureValueCounts.getOrDefault(featureValue, 0) + 1);
        }
        double entropy = 0;
        for (int count : featureValueCounts.values()) {
            double probability = (double) count / data.size();
            entropy -= probability * Math.log(probability) / Math.log(2);
        }
        return entropy;
    }

    // Split information: entropy of the partition induced by the feature's values.
    private double getSplitInfo(List<Map<String, String>> data, String feature) {
        List<String> featureValues = getFeatureValues(data, feature);
        double splitInfo = 0;
        for (String featureValue : featureValues) {
            List<Map<String, String>> subset = getSubset(data, feature, featureValue);
            double probability = (double) subset.size() / data.size();
            splitInfo -= probability * Math.log(probability) / Math.log(2);
        }
        return splitInfo;
    }

    // Conditional entropy of the class label given the feature:
    // the size-weighted entropy of each value's subset.
    private double getFeatureEntropy(List<Map<String, String>> data, String feature) {
        List<String> featureValues = getFeatureValues(data, feature);
        double featureEntropy = 0;
        for (String featureValue : featureValues) {
            List<Map<String, String>> subset = getSubset(data, feature, featureValue);
            double probability = (double) subset.size() / data.size();
            featureEntropy += probability * getEntropy(subset);
        }
        return featureEntropy;
    }

    // Distinct values that the feature takes in the data, in first-seen order.
    private List<String> getFeatureValues(List<Map<String, String>> data, String feature) {
        List<String> featureValues = new ArrayList<>();
        for (Map<String, String> row : data) {
            String featureValue = row.get(feature);
            if (!featureValues.contains(featureValue)) {
                featureValues.add(featureValue);
            }
        }
        return featureValues;
    }

    // Rows whose value for the given feature equals featureValue.
    private List<Map<String, String>> getSubset(List<Map<String, String>> data, String feature, String featureValue) {
        List<Map<String, String>> subset = new ArrayList<>();
        for (Map<String, String> row : data) {
            if (row.get(feature).equals(featureValue)) {
                subset.add(row);
            }
        }
        return subset;
    }

    private static class Node {
        private String featureName;   // splitting feature (internal nodes)
        private String featureValue;  // class label (leaf nodes)
        private Map<String, Node> children;

        // Leaf node holding a class label.
        public Node(String featureValue) {
            this.featureValue = featureValue;
        }

        // Internal node splitting on featureName.
        public Node(String featureName, String featureValue) {
            this.featureName = featureName;
            this.featureValue = featureValue;
            children = new HashMap<>();
        }

        public boolean isLeaf() {
            return children == null || children.isEmpty();
        }

        public String getFeatureName() {
            return featureName;
        }

        public String getFeatureValue() {
            return featureValue;
        }

        public void addChild(String featureValue, Node child) {
            children.put(featureValue, child);
        }

        public Node getChild(String featureValue) {
            return children.get(featureValue);
        }
    }
}
```
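For context, here is a minimal usage sketch. It assumes a tiny, made-up slice of the classic weather/play dataset, with the class label stored under the key "play" as the implementation above expects; the feature names ("outlook", "windy"), their values, and the demo class `C45DecisionTreeDemo` are all illustrative, and the demo assumes it lives in the same package as `C45DecisionTree`.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class C45DecisionTreeDemo {
    // Helper to build one training row; the "play" entry is the class label.
    private static Map<String, String> row(String outlook, String windy, String play) {
        Map<String, String> r = new HashMap<>();
        r.put("outlook", outlook);
        r.put("windy", windy);
        r.put("play", play);
        return r;
    }

    public static void main(String[] args) {
        List<Map<String, String>> data = new ArrayList<>();
        data.add(row("sunny", "false", "no"));
        data.add(row("sunny", "true", "no"));
        data.add(row("overcast", "false", "yes"));
        data.add(row("rainy", "false", "yes"));
        data.add(row("rainy", "true", "no"));

        List<String> features = new ArrayList<>();
        features.add("outlook");
        features.add("windy");

        C45DecisionTree tree = new C45DecisionTree();
        tree.train(data, features);

        // Query an unseen combination of feature values
        Map<String, String> query = new HashMap<>();
        query.put("outlook", "rainy");
        query.put("windy", "false");
        System.out.println(tree.predict(query)); // expected: yes
    }
}
```

On this toy data the tree first splits on "windy" (the higher gain ratio) and then on "outlook", so the query rainy/false reaches a "yes" leaf.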