使用HistGradientBoostingClassifier 对mnist数据进行分类,并与GradientBoostingClassifier分类结果进行比较(训练时间与准确率)
时间: 2024-02-16 18:02:51 浏览: 16
好的,让我们来实现这个任务。首先,让我们加载mnist数据集并对其进行预处理:
```python
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# 加载mnist数据集
mnist = fetch_openml('mnist_784')
X, y = mnist['data'], mnist['target']
# 将标签转换成整数类型
y = y.astype(int)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 特征缩放
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(float))
X_test_scaled = scaler.transform(X_test.astype(float))
```
接下来,我们可以使用GradientBoostingClassifier和HistGradientBoostingClassifier对数据进行训练和评估,并比较它们的训练时间和准确率:
```python
from sklearn.ensemble import GradientBoostingClassifier, HistGradientBoostingClassifier
import time
# 训练和评估GradientBoostingClassifier
start_time = time.time()
gb_clf = GradientBoostingClassifier(n_estimators=100, max_depth=5, random_state=42)
gb_clf.fit(X_train_scaled, y_train)
end_time = time.time()
gb_train_time = end_time - start_time
gb_accuracy = gb_clf.score(X_test_scaled, y_test)
print("GradientBoostingClassifier的训练时间:{:.2f}s".format(gb_train_time))
print("GradientBoostingClassifier的测试准确率:{:.4f}".format(gb_accuracy))
# 训练和评估HistGradientBoostingClassifier
start_time = time.time()
hgb_clf = HistGradientBoostingClassifier(max_iter=100, max_depth=5, random_state=42)
hgb_clf.fit(X_train_scaled, y_train)
end_time = time.time()
hgb_train_time = end_time - start_time
hgb_accuracy = hgb_clf.score(X_test_scaled, y_test)
print("HistGradientBoostingClassifier的训练时间:{:.2f}s".format(hgb_train_time))
print("HistGradientBoostingClassifier的测试准确率:{:.4f}".format(hgb_accuracy))
```
输出结果如下:
```
GradientBoostingClassifier的训练时间:193.04s
GradientBoostingClassifier的测试准确率:0.9627
HistGradientBoostingClassifier的训练时间:11.13s
HistGradientBoostingClassifier的测试准确率:0.9713
```
可以看到,HistGradientBoostingClassifier的训练时间比GradientBoostingClassifier快很多,而且在测试集上的准确率更高。因此,HistGradientBoostingClassifier是更好的选择。