ID3 决策树 Java 代码
时间: 2023-07-09 10:47:07 浏览: 120
以下是一个简单的 ID3 决策树的 Java 代码实现:
```
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
/**
 * A minimal ID3 decision-tree learner over categorical data.
 *
 * <p>Training rows are {@code Map<String, String>} instances mapping attribute names to values;
 * the class label of each row is stored under the key {@code "class"}.
 */
public class ID3DecisionTree {
    /** Key under which every training row stores its class label. */
    private static final String CLASS_KEY = "class";

    /**
     * A node of the decision tree.
     *
     * <p>Internal nodes carry the name of the attribute they split on; leaves (nodes without
     * children) carry a class label. Every node except the root also records the value of the
     * parent's split attribute on the edge that leads to it.
     */
    public static final class Node {
        private final String label;
        private final String branchValue;
        private final List<Node> children = new ArrayList<>();

        /** Creates a node with no incoming branch value (e.g. the root). */
        public Node(String label) {
            this(label, null);
        }

        /**
         * Creates a node reached from its parent via {@code branchValue}.
         *
         * @param label split attribute for internal nodes, class label for leaves
         * @param branchValue the parent's attribute value that leads here; {@code null} for a root
         */
        public Node(String label, String branchValue) {
            this.label = label;
            this.branchValue = branchValue;
        }

        /** Appends {@code child} to this node's children. */
        public void addChild(Node child) {
            children.add(Objects.requireNonNull(child, "child"));
        }

        /** Returns the split attribute for internal nodes, or the class label for leaves. */
        public String getLabel() {
            return label;
        }

        /** Returns the parent's attribute value that leads to this node, or {@code null}. */
        public String getBranchValue() {
            return branchValue;
        }

        /** Returns a read-only view of the children; use {@link #addChild} to modify. */
        public List<Node> getChildren() {
            return Collections.unmodifiableList(children);
        }

        /** Returns {@code true} when this node has no children, i.e. it holds a class label. */
        public boolean isLeaf() {
            return children.isEmpty();
        }
    }

    /**
     * Builds an ID3 decision tree.
     *
     * @param data training rows; each maps attribute names to values and holds its class label
     *     under the key {@code "class"}
     * @param attributes candidate attributes to split on; the list itself is not modified
     * @return the root of the tree, or {@code null} when {@code data} is empty
     * @throws NullPointerException if {@code data} or {@code attributes} is null
     */
    public static Node buildTree(List<Map<String, String>> data, List<String> attributes) {
        Objects.requireNonNull(data, "data");
        Objects.requireNonNull(attributes, "attributes");
        return build(data, attributes, null);
    }

    /** Recursive worker for {@link #buildTree}; {@code branchValue} labels the incoming edge. */
    private static Node build(
            List<Map<String, String>> data, List<String> attributes, String branchValue) {
        if (data.isEmpty()) {
            return null;
        }
        if (attributes.isEmpty() || allDataHasSameClass(data)) {
            return new Node(getMajorityClass(data), branchValue);
        }
        String bestAttribute = getBestAttribute(data, attributes);
        Node node = new Node(bestAttribute, branchValue);
        // Each attribute is used at most once on any root-to-leaf path; the copy is the same for
        // every branch, so it is built once outside the loop.
        List<String> remaining = new ArrayList<>(attributes);
        remaining.remove(bestAttribute);
        for (String value : getPossibleValues(data, bestAttribute)) {
            // Never empty: value was observed in data, so the recursive call never returns null.
            node.addChild(build(getSubset(data, bestAttribute, value), remaining, value));
        }
        return node;
    }

    /**
     * Predicts the class of {@code instance} by walking the tree from {@code root}.
     *
     * @param root a tree produced by {@link #buildTree}; may be {@code null}
     * @param instance attribute name to value mapping for the item to classify
     * @return the predicted class, or empty when {@code root} is null, when the instance has an
     *     attribute value for which no branch exists, or when the reached leaf label is null
     */
    public static Optional<String> classify(Node root, Map<String, String> instance) {
        Objects.requireNonNull(instance, "instance");
        Node node = root;
        while (node != null && !node.isLeaf()) {
            String value = instance.get(node.getLabel());
            Node next = null;
            for (Node child : node.getChildren()) {
                if (Objects.equals(child.getBranchValue(), value)) {
                    next = child;
                    break;
                }
            }
            node = next;
        }
        return node == null ? Optional.<String>empty() : Optional.ofNullable(node.getLabel());
    }

    /** Returns the most frequent class label; ties go to the class seen first in {@code data}. */
    private static String getMajorityClass(List<Map<String, String>> data) {
        String majorityClass = "";
        int maxCount = 0;
        for (Map.Entry<String, Integer> entry : countClasses(data).entrySet()) {
            if (entry.getValue() > maxCount) {
                majorityClass = entry.getKey();
                maxCount = entry.getValue();
            }
        }
        return majorityClass;
    }

    /** Counts rows per class label, preserving first-seen order of the labels. */
    private static Map<String, Integer> countClasses(List<Map<String, String>> data) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (Map<String, String> row : data) {
            counts.merge(row.get(CLASS_KEY), 1, Integer::sum);
        }
        return counts;
    }

    /** Returns {@code true} if every row has the same class label; {@code data} must be non-empty. */
    private static boolean allDataHasSameClass(List<Map<String, String>> data) {
        String firstClass = data.get(0).get(CLASS_KEY);
        for (Map<String, String> row : data) {
            if (!Objects.equals(row.get(CLASS_KEY), firstClass)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Picks the attribute with the highest information gain.
     *
     * <p>Gain = H(class) - H(class | attribute); since H(class) is the same for every candidate,
     * this is the attribute with the lowest conditional entropy. Ties go to the earliest attribute.
     */
    private static String getBestAttribute(List<Map<String, String>> data, List<String> attributes) {
        double minEntropy = Double.MAX_VALUE;
        String bestAttribute = "";
        for (String attribute : attributes) {
            double entropy = getConditionalEntropy(data, attribute);
            if (entropy < minEntropy) {
                minEntropy = entropy;
                bestAttribute = attribute;
            }
        }
        return bestAttribute;
    }

    /** Returns H(class | attribute): class entropy of each value's subset, weighted by its size. */
    private static double getConditionalEntropy(List<Map<String, String>> data, String attribute) {
        double entropy = 0.0;
        for (String value : getPossibleValues(data, attribute)) {
            List<Map<String, String>> subset = getSubset(data, attribute, value);
            double weight = (double) subset.size() / data.size();
            entropy += weight * getClassEntropy(subset);
        }
        return entropy;
    }

    /** Returns the Shannon entropy (base 2) of the class-label distribution in {@code data}. */
    private static double getClassEntropy(List<Map<String, String>> data) {
        double entropy = 0.0;
        for (int count : countClasses(data).values()) {
            double probability = (double) count / data.size();
            entropy -= probability * getLogBase2(probability);
        }
        return entropy;
    }

    /** Returns the distinct values of {@code attribute} in first-seen order (null included). */
    private static List<String> getPossibleValues(List<Map<String, String>> data, String attribute) {
        Set<String> values = new LinkedHashSet<>();
        for (Map<String, String> row : data) {
            values.add(row.get(attribute));
        }
        return new ArrayList<>(values);
    }

    /** Returns the rows whose {@code attribute} equals {@code value} (null-safe comparison). */
    private static List<Map<String, String>> getSubset(
            List<Map<String, String>> data, String attribute, String value) {
        List<Map<String, String>> subset = new ArrayList<>();
        for (Map<String, String> row : data) {
            if (Objects.equals(row.get(attribute), value)) {
                subset.add(row);
            }
        }
        return subset;
    }

    /** Returns log2(x). */
    private static double getLogBase2(double x) {
        return Math.log(x) / Math.log(2);
    }
}
```
这个实现使用了一个简单的 `Node` 类来表示决策树的每个节点。输入数据 `data` 中的每一行是一个 `Map`：键为属性名，值为该属性的取值，类别标签存放在键 `"class"` 下；`attributes` 是可用于划分的属性名列表。`buildTree` 方法采用递归方式构建决策树，并按 ID3 算法选择用于划分的最佳属性。`getMajorityClass` 和 `allDataHasSameClass` 方法分别用于求数据集中的多数类，以及判断是否所有数据都属于同一类别。`getBestAttribute` 方法基于信息熵选择最佳属性，`getEntropy` 方法用于计算熵。`getPossibleValues` 方法用于获取某个属性在数据集中出现过的所有取值，`getSubset` 方法用于按属性取值筛选出数据子集。`getLogBase2` 方法用于计算以 2 为底的对数。
阅读全文