Write a random forest algorithm in Java
Sure. Here is an example implementation of the random forest algorithm in Java:
```java
import java.util.*;

public class RandomForest {
    private List<DecisionTree> trees;

    public RandomForest(int numTrees, int maxDepth, int numFeatures, List<Instance> instances) {
        trees = new ArrayList<>();
        for (int i = 0; i < numTrees; i++) {
            // Bagging: each tree is trained on a bootstrap sample drawn with replacement
            List<Instance> bootstrapSample = new ArrayList<>();
            for (int j = 0; j < instances.size(); j++) {
                int index = (int) Math.floor(Math.random() * instances.size());
                bootstrapSample.add(instances.get(index));
            }
            trees.add(new DecisionTree(maxDepth, numFeatures, bootstrapSample));
        }
    }

    // Returns one prediction per tree
    public List<Double> predict(Instance instance) {
        List<Double> predictions = new ArrayList<>();
        for (DecisionTree tree : trees) {
            predictions.add(tree.predict(instance));
        }
        return predictions;
    }

    // Accuracy on a labeled set: average the per-tree outputs and threshold at 0.5
    public double evaluate(List<Instance> instances) {
        int numCorrect = 0;
        for (Instance instance : instances) {
            List<Double> predictions = predict(instance);
            double meanPrediction = 0.0;
            for (double prediction : predictions) {
                meanPrediction += prediction;
            }
            meanPrediction /= predictions.size();
            if (meanPrediction >= 0.5 && instance.label == 1.0) {
                numCorrect += 1;
            } else if (meanPrediction < 0.5 && instance.label == 0.0) {
                numCorrect += 1;
            }
        }
        return (double) numCorrect / instances.size();
    }
}
class DecisionTree {
    private int maxDepth;
    private int numFeatures;
    private Node root;

    public DecisionTree(int maxDepth, int numFeatures, List<Instance> instances) {
        this.maxDepth = maxDepth;
        this.numFeatures = numFeatures;
        buildTree(instances);
    }

    public double predict(Instance instance) {
        Node curr = root;
        while (curr.left != null && curr.right != null) {
            if (instance.features[curr.featureIndex] < curr.threshold) {
                curr = curr.left;
            } else {
                curr = curr.right;
            }
        }
        return curr.label;
    }

    public void buildTree(List<Instance> instances) {
        root = buildTreeHelper(instances, 0);
    }

    private Node buildTreeHelper(List<Instance> instances, int depth) {
        if (instances.isEmpty()) { // no data left
            return null;
        }
        if (depth >= maxDepth) { // depth limit reached
            return new Node(getLabel(instances));
        }
        if (allInstancesSameLabel(instances)) { // pure node: nothing left to split
            return new Node(getLabel(instances));
        }
        // Random feature subspace: each split considers only numFeatures randomly chosen features
        List<Integer> featureIndices = new ArrayList<>();
        for (int i = 0; i < instances.get(0).features.length; i++) {
            featureIndices.add(i);
        }
        Collections.shuffle(featureIndices);
        List<Integer> chosenFeatureIndices =
                featureIndices.subList(0, Math.min(numFeatures, featureIndices.size()));
        int bestFeatureIndex = -1;
        double bestThreshold = 0.0;
        double bestInformationGain = -1.0;
        for (int featureIndex : chosenFeatureIndices) { // find the best feature
            List<Double> featureValues = new ArrayList<>();
            for (Instance instance : instances) {
                featureValues.add(instance.features[featureIndex]);
            }
            Collections.sort(featureValues);
            // Scan candidate thresholds: midpoints between consecutive sorted values
            for (int i = 1; i < featureValues.size(); i++) {
                double threshold = (featureValues.get(i - 1) + featureValues.get(i)) / 2.0;
                List<Instance> leftInstances = new ArrayList<>();
                List<Instance> rightInstances = new ArrayList<>();
                for (Instance instance : instances) { // partition data
                    if (instance.features[featureIndex] < threshold) {
                        leftInstances.add(instance);
                    } else {
                        rightInstances.add(instance);
                    }
                }
                double informationGain = getInformationGain(instances, leftInstances, rightInstances);
                if (informationGain > bestInformationGain) {
                    bestInformationGain = informationGain;
                    bestFeatureIndex = featureIndex;
                    bestThreshold = threshold;
                }
            }
        }
        List<Instance> leftInstances = new ArrayList<>();
        List<Instance> rightInstances = new ArrayList<>();
        for (Instance instance : instances) { // partition on the best split found
            if (instance.features[bestFeatureIndex] < bestThreshold) {
                leftInstances.add(instance);
            } else {
                rightInstances.add(instance);
            }
        }
        // If the best split leaves one side empty (e.g. all chosen feature values are equal),
        // stop and return a leaf rather than creating a dangling child
        if (leftInstances.isEmpty() || rightInstances.isEmpty()) {
            return new Node(getLabel(instances));
        }
        Node left = buildTreeHelper(leftInstances, depth + 1);
        Node right = buildTreeHelper(rightInstances, depth + 1);
        return new Node(bestFeatureIndex, bestThreshold, left, right);
    }

    private boolean allInstancesSameLabel(List<Instance> instances) {
        double firstLabel = instances.get(0).label;
        for (Instance instance : instances) {
            if (instance.label != firstLabel) {
                return false;
            }
        }
        return true;
    }

    // Fraction of positive labels; doubles as the leaf's predicted probability
    private double getLabel(List<Instance> instances) {
        double sum = 0.0;
        for (Instance instance : instances) {
            sum += instance.label;
        }
        return sum / instances.size();
    }

    private double getInformationGain(List<Instance> instances,
                                      List<Instance> leftInstances,
                                      List<Instance> rightInstances) {
        return getEntropy(instances)
                - ((double) leftInstances.size() / instances.size()) * getEntropy(leftInstances)
                - ((double) rightInstances.size() / instances.size()) * getEntropy(rightInstances);
    }

    // Binary entropy in nats; the log base does not affect which split wins
    private double getEntropy(List<Instance> instances) {
        if (instances.isEmpty()) { // an empty side contributes no entropy
            return 0;
        }
        double p = getLabel(instances);
        if (p == 0 || p == 1) {
            return 0;
        }
        return -p * Math.log(p) - (1 - p) * Math.log(1 - p);
    }
}
class Node {
    public int featureIndex; // the index of the feature to split on
    public double threshold; // the threshold value to split on
    public Node left;        // the left child node
    public Node right;       // the right child node
    public double label;     // the predicted label

    public Node(int featureIndex, double threshold, Node left, Node right) {
        this.featureIndex = featureIndex;
        this.threshold = threshold;
        this.left = left;
        this.right = right;
    }

    public Node(double label) {
        this.label = label;
    }
}
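
// Note: the original snippet uses an Instance type without defining it.
// Below is a minimal definition consistent with how Instance is used above
// (the constructor signature is an assumption, not from the original):
class Instance {
    public double[] features; // feature vector
    public double label;      // binary label: 0.0 or 1.0

    public Instance(double[] features, double label) {
        this.features = features;
        this.label = label;
    }
}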
```
This code implements a "bagging + decision trees" style of random forest: during training, each tree is fit on a bootstrap sample drawn at random from the training data, and each split considers only a random subset of the features. At prediction time, every decision tree outputs a probability, and the forest averages the per-tree probabilities to produce the final prediction.
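As a quick sanity check, here is a minimal usage sketch (in its own file). The toy dataset, the hyperparameter values (10 trees, max depth 3, 1 random feature per split), and the `RandomForestDemo` class are illustrative assumptions, not part of the original code:
```java
import java.util.*;

public class RandomForestDemo {
    public static void main(String[] args) {
        // Toy binary-classification data (made up for illustration):
        // label 1.0 roughly when the first feature is large
        List<Instance> data = new ArrayList<>();
        data.add(new Instance(new double[]{1.0, 0.2}, 0.0));
        data.add(new Instance(new double[]{1.5, 0.1}, 0.0));
        data.add(new Instance(new double[]{2.0, 0.4}, 0.0));
        data.add(new Instance(new double[]{7.0, 0.3}, 1.0));
        data.add(new Instance(new double[]{8.0, 0.5}, 1.0));
        data.add(new Instance(new double[]{9.0, 0.2}, 1.0));

        // Hyperparameters chosen arbitrarily: 10 trees, depth 3, 1 feature per split
        RandomForest forest = new RandomForest(10, 3, 1, data);

        // Training-set accuracy (on such an easy toy set this should be near 1.0)
        System.out.println("Accuracy: " + forest.evaluate(data));

        // Average the per-tree probabilities for a single query point
        Instance query = new Instance(new double[]{7.5, 0.3}, 1.0);
        double mean = forest.predict(query).stream()
                .mapToDouble(Double::doubleValue).average().orElse(0.0);
        System.out.println("P(label = 1) ~ " + mean);
    }
}
```
Averaging probabilities and thresholding at 0.5, as `evaluate` does, is one of two common aggregation schemes; the other is a hard majority vote over each tree's class prediction. For 0/1 labels the two coincide whenever every tree outputs a pure leaf.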