Java 实现决策树代码
时间: 2023-06-20 13:10:28 浏览: 61
以下是一个用 Java 实现决策树的简单示例代码：
```
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * A categorical decision-tree classifier built with the ID3 algorithm.
 *
 * <p>Each training row holds one value per attribute, in the order given to the constructor,
 * followed by the class label as its last element. At every node the tree splits on the
 * remaining attribute with the highest information gain; an attribute is tested at most once
 * on any root-to-leaf path.
 */
public class DecisionTree {
    /** Natural log of 2, used to turn {@code Math.log} into a base-2 logarithm. */
    private static final double LOG_2 = Math.log(2);

    /** Root of the trained tree; {@code null} until {@link #train} has been called. */
    private Node root;
    /** Attribute names; position i names column i of every data row. */
    private final ArrayList<String> attributes;

    /** A tree node: either a leaf holding a label, or a test on one attribute. */
    private static final class Node {
        /** Attribute tested at this node; {@code null} for leaves. */
        private final String attribute;
        /** Whether this node is a leaf. */
        private final boolean leaf;
        /**
         * For a leaf, the predicted label. For an internal node, the majority label of the
         * training rows that reached it, returned when an instance's value has no child.
         */
        private final String classification;
        /** Subtrees keyed by the value of {@link #attribute}; empty for leaves. */
        private final HashMap<String, Node> children = new HashMap<String, Node>();

        private Node(String attribute, boolean leaf, String classification) {
            this.attribute = attribute;
            this.leaf = leaf;
            this.classification = classification;
        }

        /** Creates a leaf that predicts {@code classification}. */
        static Node newLeaf(String classification) {
            return new Node(null, true, classification);
        }

        /** Creates an internal node testing {@code attribute}, falling back to {@code majority}. */
        static Node newSplit(String attribute, String majority) {
            return new Node(attribute, false, majority);
        }

        boolean isLeaf() {
            return leaf;
        }

        String getAttribute() {
            return attribute;
        }

        String getClassification() {
            return classification;
        }

        void addChild(String value, Node child) {
            children.put(value, child);
        }

        /** Returns the subtree for {@code value}, or {@code null} if none was trained. */
        Node getChild(String value) {
            return children.get(value);
        }
    }

    /**
     * Creates an untrained tree over the given attributes.
     *
     * @param attributes attribute names in the same order as the columns of each data row; the
     *     list is copied, so later changes by the caller have no effect
     * @throws NullPointerException if {@code attributes} is {@code null}
     */
    public DecisionTree(ArrayList<String> attributes) {
        if (attributes == null) {
            throw new NullPointerException("attributes");
        }
        this.root = null;
        this.attributes = new ArrayList<String>(attributes);
    }

    /**
     * Builds the tree from labelled training rows, replacing any previously trained tree.
     *
     * @param data training rows; each row holds one value per attribute followed by the class
     *     label as its last element
     * @throws IllegalArgumentException if {@code data} is null or empty, or if a row is too short
     *     to hold every attribute value plus a label
     */
    public void train(ArrayList<ArrayList<String>> data) {
        if (data == null || data.isEmpty()) {
            throw new IllegalArgumentException("training data must not be empty");
        }
        for (ArrayList<String> row : data) {
            if (row.size() <= attributes.size()) {
                throw new IllegalArgumentException("expected " + attributes.size()
                        + " attribute values followed by a class label, got " + row);
            }
        }
        this.root = buildTree(data, this.attributes, null);
    }

    /**
     * Predicts the class label of one instance.
     *
     * <p>If the instance has an attribute value that was never seen at some node during
     * training, the majority label of the training rows at that node is returned.
     *
     * @param data attribute values in constructor order; a trailing label, if present, is ignored
     * @return the predicted class label
     * @throws IllegalStateException if {@link #train} has not been called
     */
    public String classify(ArrayList<String> data) {
        if (root == null) {
            throw new IllegalStateException("classify() called before train()");
        }
        return classify(data, root);
    }

    /**
     * Recursively grows the subtree for the rows in {@code data}.
     *
     * @param data rows reaching this node (full rows; no columns are removed)
     * @param remaining attributes not yet tested on the path from the root; not modified
     * @param parentMajority label used if {@code data} is empty
     * @return the root of the subtree
     */
    private Node buildTree(List<ArrayList<String>> data, List<String> remaining,
            String parentMajority) {
        if (data.isEmpty()) {
            // Subsets are built from observed values, so this is a defensive fallback only.
            return Node.newLeaf(parentMajority);
        }
        String pure = getClassification(data);
        if (pure != null) {
            return Node.newLeaf(pure);
        }
        String majority = getMajorityClass(data);
        if (remaining.isEmpty()) {
            // All attributes are fixed along this path but labels still conflict: vote.
            return Node.newLeaf(majority);
        }
        String best = getBestAttribute(data, remaining);
        // Drop the chosen attribute for the children. Rows keep every column, so indices are
        // always looked up in the full attribute list.
        List<String> childRemaining = new ArrayList<String>(remaining);
        childRemaining.remove(best);
        Node node = Node.newSplit(best, majority);
        for (String value : getAttributeValues(data, best)) {
            node.addChild(value, buildTree(getSubset(data, best, value), childRemaining, majority));
        }
        return node;
    }

    /** Walks from {@code node} down to a leaf, falling back to a node's majority on unseen values. */
    private String classify(ArrayList<String> data, Node node) {
        Node current = node;
        while (!current.isLeaf()) {
            String value = data.get(indexOf(current.getAttribute()));
            Node child = current.getChild(value);
            if (child == null) {
                return current.getClassification();
            }
            current = child;
        }
        return current.getClassification();
    }

    /** Returns the label shared by every row of non-empty {@code data}, or {@code null} if they differ. */
    private String getClassification(List<ArrayList<String>> data) {
        String first = labelOf(data.get(0));
        for (ArrayList<String> row : data) {
            if (!labelOf(row).equals(first)) {
                return null;
            }
        }
        return first;
    }

    /** Returns the most frequent label in {@code data}; ties go to the label seen first. */
    private String getMajorityClass(List<ArrayList<String>> data) {
        String majority = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : getClassCounts(data).entrySet()) {
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                majority = entry.getKey();
            }
        }
        return majority;
    }

    /**
     * Returns the candidate with the highest information gain (first one wins ties).
     * {@code candidates} must be non-empty, so a non-null attribute is always returned even
     * when every gain is zero.
     */
    private String getBestAttribute(List<ArrayList<String>> data, List<String> candidates) {
        String best = null;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (String attribute : candidates) {
            double gain = getInformationGain(data, attribute);
            if (gain > bestGain) {
                bestGain = gain;
                best = attribute;
            }
        }
        return best;
    }

    /** Entropy of {@code data} minus the weighted entropy after splitting on {@code attribute}. */
    private double getInformationGain(List<ArrayList<String>> data, String attribute) {
        double entropyAfter = 0;
        for (String value : getAttributeValues(data, attribute)) {
            List<ArrayList<String>> subset = getSubset(data, attribute, value);
            entropyAfter += (double) subset.size() / data.size() * getEntropy(subset);
        }
        return getEntropy(data) - entropyAfter;
    }

    /** Base-2 Shannon entropy of the label distribution in non-empty {@code data}. */
    private double getEntropy(List<ArrayList<String>> data) {
        double entropy = 0;
        for (int count : getClassCounts(data).values()) {
            double probability = (double) count / data.size();
            entropy -= probability * Math.log(probability) / LOG_2;
        }
        return entropy;
    }

    /** Distinct values of {@code attribute} in {@code data}, in order of first appearance. */
    private Set<String> getAttributeValues(List<ArrayList<String>> data, String attribute) {
        int index = indexOf(attribute);
        Set<String> values = new LinkedHashSet<String>();
        for (ArrayList<String> row : data) {
            values.add(row.get(index));
        }
        return values;
    }

    /** Rows of {@code data} whose {@code attribute} equals {@code value}; rows are kept whole. */
    private List<ArrayList<String>> getSubset(List<ArrayList<String>> data, String attribute,
            String value) {
        int index = indexOf(attribute);
        List<ArrayList<String>> subset = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> row : data) {
            if (row.get(index).equals(value)) {
                subset.add(row);
            }
        }
        return subset;
    }

    /** Label frequencies in {@code data}, iterated in order of first appearance. */
    private Map<String, Integer> getClassCounts(List<ArrayList<String>> data) {
        Map<String, Integer> counts = new LinkedHashMap<String, Integer>();
        for (ArrayList<String> row : data) {
            String label = labelOf(row);
            Integer seen = counts.get(label);
            counts.put(label, seen == null ? 1 : seen + 1);
        }
        return counts;
    }

    /** The class label of a training row: its last element. */
    private static String labelOf(List<String> row) {
        return row.get(row.size() - 1);
    }

    /** Column index of {@code attribute} in every data row. */
    private int indexOf(String attribute) {
        return attributes.indexOf(attribute);
    }
}
```
这个决策树实现基于 ID3 算法：先用训练数据构建决策树，再用构建好的树对新样本进行分类。训练数据的每一行依次是各属性的取值，最后一列是类别标签。实现中通过递归构建树，每个节点用哈希表（HashMap）按属性取值保存其子节点。