xgboost预测鸢尾花数据集
时间: 2023-09-04 16:06:32 浏览: 75
基于鸢尾花的数据预测模型
下面是使用 XGBoost 预测鸢尾花数据集的代码示例:
```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import xgboost as xgb
# 加载数据集
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
# 构建 DMatrix 数据格式
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)
# 设置参数
param = {
'max_depth': 3, # 树的最大深度
'eta': 0.3, # 学习率
'objective': 'multi:softmax', # 多分类问题
'num_class': 3 # 类别数
}
# 训练模型
num_round = 20
bst = xgb.train(param, dtrain, num_round)
# 预测结果
y_pred = bst.predict(dtest)
# 计算准确率
accuracy = sum(y_test == y_pred) / len(y_test)
print('Accuracy:', accuracy)
```
这里使用了 XGBoost 的 DMatrix 数据格式,它可以将数据集转换成 XGBoost 可以处理的格式。参数中设置了树的最大深度、学习率、目标函数(多分类问题)和类别数,并且训练了 20 次。最后计算了预测准确率并输出。
阅读全文