使用Java编写一个xgboost4j库全部参数设置的示例
时间: 2024-09-11 13:14:32 浏览: 13
xgboost4j是一个Java库,用于实现XGBoost机器学习算法。以下是一个使用xgboost4j库设置全部参数的示例:
```java
import ml.dmlc.xgboost4j.java.*;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import java.util.HashMap;
import java.util.Map;
public class XGBoostExample {
public static void main(String[] args) throws XGBoostError {
// 准备训练数据和标签
float[][] trainData = new float[][]{{1, 2}, {3, 4}, {5, 6}};
float[] trainLabels = new float[]{0, 1, 0};
DMatrix trainMat = new DMatrix(trainData, trainLabels);
// 设置XGBoost参数
Map<String, Object> params = new HashMap<>();
params.put("eta", 0.3);
params.put("max_depth", 2);
params.put("objective", "binary:logistic");
params.put("nthread", 2);
params.put("eval_metric", "logloss");
// 训练模型
Booster booster = XGBoost.train(trainMat, params, 10, new HashMap<>());
// 保存模型
booster.saveModel("xgboost.model");
// 加载模型
Booster loadedBooster = XGBoost.loadModel("xgboost.model");
// 预测新数据
float[][] testData = new float[][]{{7, 8}, {9, 10}};
DMatrix testMat = new DMatrix(testData);
float[][] predictions = loadedBooster.predict(testMat);
// 输出预测结果
for (float[] prediction : predictions) {
System.out.println("Prediction: " + prediction[0]);
}
}
}
```
在这个示例中,我们首先导入了所需的xgboost4j库类。然后,我们创建了一个包含训练数据的DMatrix对象,并设置了XGBoost的参数。接下来,我们使用`XGBoost.train()`方法训练模型,并将训练好的模型保存到文件。最后,我们加载模型并对新数据进行预测。