CART classification algorithm in Java
Below is a simple Java code example of a CART-style decision tree classifier:
```java
import java.util.ArrayList;
import java.util.HashMap;

public class CART {

    public static void main(String[] args) {
        // Build the training set: each record maps attribute names to values,
        // with the target label stored under the key "class"
        ArrayList<HashMap<String, String>> trainData = new ArrayList<>();
        trainData.add(record("youth",       "high",   "no",  "fair",      "reject"));
        trainData.add(record("youth",       "high",   "no",  "excellent", "reject"));
        trainData.add(record("middle-aged", "high",   "no",  "excellent", "approve"));
        trainData.add(record("middle-aged", "medium", "no",  "excellent", "approve"));
        trainData.add(record("middle-aged", "medium", "yes", "fair",      "approve"));
        trainData.add(record("senior",      "medium", "yes", "excellent", "approve"));
        trainData.add(record("senior",      "low",    "yes", "excellent", "reject"));
        trainData.add(record("senior",      "low",    "no",  "fair",      "reject"));
        // Train the decision tree model
        DecisionTreeModel model = train(trainData);
        System.out.println("Decision tree model: " + model);
        // Predict the class of a new record
        HashMap<String, String> newData = new HashMap<>();
        newData.put("age", "youth");
        newData.put("income", "medium");
        newData.put("student", "no");
        newData.put("credit_rating", "fair");
        String result = predict(newData, model);
        System.out.println("Prediction for new data: " + result);
    }

    /**
     * Builds a single training record (loan-approval example data).
     */
    private static HashMap<String, String> record(String age, String income, String student,
                                                  String creditRating, String cls) {
        HashMap<String, String> data = new HashMap<>();
        data.put("age", age);
        data.put("income", income);
        data.put("student", student);
        data.put("credit_rating", creditRating);
        data.put("class", cls);
        return data;
    }
    /**
     * Trains a decision tree model on the given records.
     * @param trainData training set
     * @return decision tree model
     */
    public static DecisionTreeModel train(ArrayList<HashMap<String, String>> trainData) {
        // Collect the candidate splitting attributes; the label "class" is not a candidate
        ArrayList<String> attributeList = new ArrayList<>();
        for (String key : trainData.get(0).keySet()) {
            if (!key.equals("class")) {
                attributeList.add(key);
            }
        }
        // Build the decision tree recursively, starting from the root node
        DecisionTreeModel model = new DecisionTreeModel();
        buildDecisionTree(trainData, attributeList, model);
        return model;
    }
    /**
     * Recursively builds the decision tree.
     * @param trainData training records reaching this node
     * @param attributeList attributes still available for splitting
     * @param model the tree node to fill in
     */
    public static void buildDecisionTree(ArrayList<HashMap<String, String>> trainData,
                                         ArrayList<String> attributeList, DecisionTreeModel model) {
        // Remember the records at this node so prediction can fall back to the majority class
        model.trainData = trainData;
        // If all records belong to the same class, make this node a leaf and return
        boolean isSameClass = true;
        String firstClass = trainData.get(0).get("class");
        for (HashMap<String, String> data : trainData) {
            if (!data.get("class").equals(firstClass)) {
                isSameClass = false;
                break;
            }
        }
        if (isSameClass) {
            model.isLeaf = true;
            model.className = firstClass;
            return;
        }
        // If no attributes are left, make this node a leaf labelled with the majority class
        if (attributeList.isEmpty()) {
            model.isLeaf = true;
            model.className = getMostCommonClass(trainData);
            return;
        }
        // Choose the best attribute, i.e. the one with the largest information gain
        String bestAttribute = getBestAttribute(trainData, attributeList);
        model.attributeName = bestAttribute;
        // Split the records on the best attribute
        ArrayList<ArrayList<HashMap<String, String>>> splitDataList = splitData(trainData, bestAttribute);
        // Recursively build one subtree per attribute value
        ArrayList<String> newAttributeList = new ArrayList<>(attributeList);
        newAttributeList.remove(bestAttribute); // the chosen attribute is no longer a candidate
        for (ArrayList<HashMap<String, String>> splitData : splitDataList) {
            DecisionTreeModel subModel = new DecisionTreeModel();
            // Record which attribute value this branch represents so predict() can follow it
            subModel.attributeValue = splitData.get(0).get(bestAttribute);
            model.subModelList.add(subModel);
            buildDecisionTree(splitData, newAttributeList, subModel);
        }
    }
    /**
     * Predicts the class of a new record.
     * @param newData new record
     * @param model decision tree model
     * @return predicted class
     */
    public static String predict(HashMap<String, String> newData, DecisionTreeModel model) {
        // A leaf node simply returns its class
        if (model.isLeaf) {
            return model.className;
        }
        // Follow the branch whose attribute value matches the new record
        String attributeValue = newData.get(model.attributeName);
        for (DecisionTreeModel subModel : model.subModelList) {
            if (subModel.attributeValue.equals(attributeValue)) {
                return predict(newData, subModel);
            }
        }
        // No matching branch: fall back to the majority class of the records at this node
        return getMostCommonClass(model.trainData);
    }
    /**
     * Returns the most common class in a data set.
     * @param trainData training records
     * @return most common class
     */
    public static String getMostCommonClass(ArrayList<HashMap<String, String>> trainData) {
        // Count how many records belong to each class
        HashMap<String, Integer> classCountMap = new HashMap<>();
        for (HashMap<String, String> data : trainData) {
            String className = data.get("class");
            if (classCountMap.containsKey(className)) {
                classCountMap.put(className, classCountMap.get(className) + 1);
            } else {
                classCountMap.put(className, 1);
            }
        }
        // Pick the class with the highest count
        String mostCommonClass = "";
        int maxCount = -1;
        for (String className : classCountMap.keySet()) {
            int count = classCountMap.get(className);
            if (count > maxCount) {
                mostCommonClass = className;
                maxCount = count;
            }
        }
        return mostCommonClass;
    }
    /**
     * Returns the attribute with the largest information gain.
     * @param trainData training records
     * @param attributeList candidate attributes
     * @return best splitting attribute
     */
    public static String getBestAttribute(ArrayList<HashMap<String, String>> trainData, ArrayList<String> attributeList) {
        String bestAttribute = "";
        double maxInformationGain = -1;
        for (String attribute : attributeList) {
            double informationGain = calculateInformationGain(trainData, attribute);
            if (informationGain > maxInformationGain) {
                bestAttribute = attribute;
                maxInformationGain = informationGain;
            }
        }
        return bestAttribute;
    }
    /**
     * Splits the records into groups, one group per value of the given attribute.
     * @param trainData training records
     * @param attributeName attribute to split on
     * @return list of record groups
     */
    public static ArrayList<ArrayList<HashMap<String, String>>> splitData(ArrayList<HashMap<String, String>> trainData, String attributeName) {
        ArrayList<ArrayList<HashMap<String, String>>> splitDataList = new ArrayList<>();
        for (HashMap<String, String> data : trainData) {
            String attributeValue = data.get(attributeName);
            boolean isSplitDataExist = false;
            // Add the record to an existing group with the same attribute value, if there is one
            for (ArrayList<HashMap<String, String>> splitData : splitDataList) {
                if (splitData.get(0).get(attributeName).equals(attributeValue)) {
                    splitData.add(data);
                    isSplitDataExist = true;
                    break;
                }
            }
            // Otherwise start a new group for this attribute value
            if (!isSplitDataExist) {
                ArrayList<HashMap<String, String>> newSplitData = new ArrayList<>();
                newSplitData.add(data);
                splitDataList.add(newSplitData);
            }
        }
        return splitDataList;
    }
    /**
     * Calculates the information gain obtained by splitting on the given attribute.
     * @param trainData training records
     * @param attributeName attribute name
     * @return information gain
     */
    public static double calculateInformationGain(ArrayList<HashMap<String, String>> trainData, String attributeName) {
        // Entropy of the whole data set
        double entropy = calculateEntropy(trainData);
        // Weighted entropy of the groups produced by the split
        double splitEntropy = 0;
        ArrayList<ArrayList<HashMap<String, String>>> splitDataList = splitData(trainData, attributeName);
        for (ArrayList<HashMap<String, String>> splitData : splitDataList) {
            double splitDataEntropy = calculateEntropy(splitData);
            double splitDataProbability = (double) splitData.size() / trainData.size();
            splitEntropy += splitDataEntropy * splitDataProbability;
        }
        // Information gain = entropy before the split minus weighted entropy after it
        return entropy - splitEntropy;
    }
    /**
     * Calculates the entropy of a data set.
     * @param dataList data set
     * @return entropy in bits
     */
    public static double calculateEntropy(ArrayList<HashMap<String, String>> dataList) {
        // Count how many records belong to each class
        HashMap<String, Integer> classCountMap = new HashMap<>();
        for (HashMap<String, String> data : dataList) {
            String className = data.get("class");
            if (classCountMap.containsKey(className)) {
                classCountMap.put(className, classCountMap.get(className) + 1);
            } else {
                classCountMap.put(className, 1);
            }
        }
        // Entropy = -sum(p * log2(p)) over all classes
        double entropy = 0;
        for (String className : classCountMap.keySet()) {
            double probability = (double) classCountMap.get(className) / dataList.size();
            entropy -= probability * Math.log(probability) / Math.log(2);
        }
        return entropy;
    }
}

/**
 * A decision tree node: internal nodes store the attribute they split on,
 * leaf nodes store the predicted class.
 */
class DecisionTreeModel {
    public boolean isLeaf;                                // whether this node is a leaf
    public String attributeName;                          // attribute this node splits on
    public String attributeValue;                         // attribute value of the branch leading to this node
    public ArrayList<DecisionTreeModel> subModelList;     // child nodes
    public String className;                              // predicted class (for leaf nodes)
    public ArrayList<HashMap<String, String>> trainData;  // training records reaching this node

    public DecisionTreeModel() {
        this.isLeaf = false;
        this.attributeName = "";
        this.attributeValue = "";
        this.subModelList = new ArrayList<>();
        this.className = "";
        this.trainData = new ArrayList<>();
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (isLeaf) {
            sb.append(className);
        } else {
            sb.append(attributeName).append(" -> ");
            for (DecisionTreeModel subModel : subModelList) {
                sb.append(subModel.attributeValue).append(": ").append(subModel.toString()).append("; ");
            }
        }
        return sb.toString();
    }
}
```
This example implements a simplified, CART-style decision tree and provides both training and prediction methods. Because the data set is small, no pruning or other optimizations are performed; in real applications the code should be adapted to the task at hand.
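One natural improvement: the example above selects splits by information gain (entropy), as in ID3, whereas standard CART uses the Gini index and binary splits. Below is a minimal, untested sketch of a Gini-based criterion under the same record format (`ArrayList<HashMap<String, String>>` with the label stored under the `class` key); the method names `calculateGini` and `calculateSplitGini` are illustrative and not part of the code above:
```java
// Gini impurity of a data set: Gini(D) = 1 - sum over classes k of p_k^2,
// where p_k is the fraction of records in class k (0 = pure, larger = more mixed).
public static double calculateGini(ArrayList<HashMap<String, String>> dataList) {
    HashMap<String, Integer> classCountMap = new HashMap<>();
    for (HashMap<String, String> data : dataList) {
        classCountMap.merge(data.get("class"), 1, Integer::sum);
    }
    double gini = 1.0;
    for (int count : classCountMap.values()) {
        double p = (double) count / dataList.size();
        gini -= p * p;
    }
    return gini;
}

// Weighted Gini impurity after splitting on the given attribute,
// reusing the splitData method from the example above.
public static double calculateSplitGini(ArrayList<HashMap<String, String>> trainData, String attributeName) {
    double splitGini = 0;
    for (ArrayList<HashMap<String, String>> group : splitData(trainData, attributeName)) {
        splitGini += calculateGini(group) * ((double) group.size() / trainData.size());
    }
    return splitGini;
}
```
With this criterion, `getBestAttribute` would pick the attribute that minimizes `calculateSplitGini` instead of maximizing information gain; a full CART implementation would additionally restrict each split to two branches and add cost-complexity pruning.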