xgboost的三分类代码
时间: 2024-08-14 14:04:14 浏览: 45
XGBoost是一个高效的梯度提升库,常用于机器学习任务。下面是一个简单的Python示例,展示如何使用XGBoost进行三分类(Multi-class classification):
```python
# 导入必要的库
import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 加载数据集(这里以鸢尾花为例)
iris = load_iris()
X = iris.data
y = iris.target
# 将类别标签转换为数值编码(假设三个类别分别为0, 1, 2)
le = preprocessing.LabelEncoder()
y_encoded = le.fit_transform(y)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
# 定义DMatrix作为XGBoost的数据结构
dtrain = xgb.DMatrix(data=X_train, label=y_train)
dtest = xgb.DMatrix(data=X_test, label=y_test)
# 参数设置(你可以根据需要调整参数)
param = {
'objective': 'multi:softmax', # 使用多类概率输出
'num_class': 3, # 有三个类别
'eval_metric': ['mlogloss'], # 使用对数损失评估模型
}
# 训练模型
num_round = 100
bst = xgb.train(param, dtrain, num_round)
# 预测
preds = bst.predict(dtest)
# 解码预测结果回原类别
y_pred = le.inverse_transform(preds.argmax(axis=1))
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
```
在这个例子中,我们首先加载了一个预装的数据集(如鸢尾花),然后将其分为训练集和测试集。接下来,我们将类别标签编码,以便XGBoost可以处理。接着,我们创建了`DMatrix`对象,并设置了相应的参数。最后,我们训练模型并进行预测,计算预测的准确性。
阅读全文