adaboost算法java代码_04-04 AdaBoost算法代码(鸢尾花分类)
时间: 2023-07-05 22:25:04 浏览: 121
好的,以下是使用Java实现的AdaBoost算法代码,用于鸢尾花分类:
```java
import java.util.ArrayList;
import java.util.Collections;
public class AdaBoost {
// 训练数据集
private ArrayList<ArrayList<Double>> dataSet;
// 类别标签
private ArrayList<Integer> labels;
// 弱分类器数目
private int weakNum;
// 训练好的弱分类器集合
private ArrayList<WeakClassifier> weakClassifiers;
// AdaBoost构造函数
public AdaBoost(ArrayList<ArrayList<Double>> dataSet, ArrayList<Integer> labels, int weakNum) {
this.dataSet = dataSet;
this.labels = labels;
this.weakNum = weakNum;
this.weakClassifiers = new ArrayList<>();
}
// 训练分类器
public void train() {
int size = dataSet.size();
// 初始化权重向量
ArrayList<Double> weights = new ArrayList<>();
for (int i = 0; i < size; i++) {
weights.add(1.0 / size);
}
// 训练 weakNum 个弱分类器
for (int i = 0; i < weakNum; i++) {
// 训练单个弱分类器
WeakClassifier weakClassifier = new WeakClassifier(dataSet, labels, weights);
weakClassifier.train();
// 计算错误率
double error = 0.0;
for (int j = 0; j < size; j++) {
if (weakClassifier.predict(dataSet.get(j)) != labels.get(j)) {
error += weights.get(j);
}
}
// 计算弱分类器权重
double alpha = 0.5 * Math.log((1 - error) / error);
weakClassifier.setAlpha(alpha);
// 更新权重向量
for (int j = 0; j < size; j++) {
if (weakClassifier.predict(dataSet.get(j)) == labels.get(j)) {
weights.set(j, weights.get(j) * Math.exp(-alpha));
} else {
weights.set(j, weights.get(j) * Math.exp(alpha));
}
}
// 归一化权重向量
double sum = 0.0;
for (int j = 0; j < size; j++) {
sum += weights.get(j);
}
for (int j = 0; j < size; j++) {
weights.set(j, weights.get(j) / sum);
}
// 将训练好的弱分类器加入集合
weakClassifiers.add(weakClassifier);
}
}
// 预测分类结果
public int predict(ArrayList<Double> data) {
double sum = 0.0;
for (WeakClassifier wc : weakClassifiers) {
sum += wc.predict(data) * wc.getAlpha();
}
if (sum > 0) {
return 1;
} else {
return -1;
}
}
// 测试分类器
public void test(ArrayList<ArrayList<Double>> testData, ArrayList<Integer> testLabels) {
int errorNum = 0;
int size = testData.size();
for (int i = 0; i < size; i++) {
if (predict(testData.get(i)) != testLabels.get(i)) {
errorNum++;
}
}
double accuracy = 1 - (double) errorNum / size;
System.out.println("Accuracy: " + accuracy);
}
// 主函数
public static void main(String[] args) {
// 读取数据集
ArrayList<ArrayList<Double>> dataSet = Util.loadDataSet("iris.data");
// 打乱数据集顺序
Collections.shuffle(dataSet);
// 获取标签
ArrayList<Integer> labels = new ArrayList<>();
for (ArrayList<Double> data : dataSet) {
if (data.get(data.size() - 1) == 1) {
labels.add(1);
} else {
labels.add(-1);
}
}
// 划分训练集和测试集
ArrayList<ArrayList<Double>> trainData = new ArrayList<>();
ArrayList<ArrayList<Double>> testData = new ArrayList<>();
ArrayList<Integer> trainLabels = new ArrayList<>();
ArrayList<Integer> testLabels = new ArrayList<>();
for (int i = 0; i < dataSet.size(); i++) {
if (i % 5 == 0) {
testData.add(dataSet.get(i));
testLabels.add(labels.get(i));
} else {
trainData.add(dataSet.get(i));
trainLabels.add(labels.get(i));
}
}
// 训练 AdaBoost 分类器
AdaBoost adaBoost = new AdaBoost(trainData, trainLabels, 10);
adaBoost.train();
// 测试分类器
adaBoost.test(testData, testLabels);
}
}
```
需要注意的是,此代码中的 `WeakClassifier` 类是用于实现单个弱分类器的训练和预测的,需要自行实现。同时,数据集的加载和处理部分也需要根据实际情况进行修改。
阅读全文