使用HistGradientBoostingClassifier 对mnist数据进行分类,并与GradientBoostingClassifier分类结果进行比较(训练时间与准确率
时间: 2024-02-16 14:02:54 浏览: 128
好的,以下是代码实现:
```python
from sklearn.datasets import fetch_openml
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import GradientBoostingClassifier, HistGradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import time
# 加载数据集
mnist = fetch_openml('mnist_784')
X, y = mnist.data, mnist.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 训练和测试GradientBoostingClassifier
start_time = time.time()
gbc = GradientBoostingClassifier()
gbc.fit(X_train, y_train)
gbc_train_time = time.time() - start_time
gbc_pred = gbc.predict(X_test)
gbc_accuracy = accuracy_score(y_test, gbc_pred)
# 训练和测试HistGradientBoostingClassifier
start_time = time.time()
hgbc = HistGradientBoostingClassifier()
hgbc.fit(X_train, y_train)
hgbc_train_time = time.time() - start_time
hgbc_pred = hgbc.predict(X_test)
hgbc_accuracy = accuracy_score(y_test, hgbc_pred)
# 输出结果
print("GradientBoostingClassifier Train Time: {:.2f}s".format(gbc_train_time))
print("GradientBoostingClassifier Test Accuracy: {:.4f}".format(gbc_accuracy))
print("HistGradientBoostingClassifier Train Time: {:.2f}s".format(hgbc_train_time))
print("HistGradientBoostingClassifier Test Accuracy: {:.4f}".format(hgbc_accuracy))
```
运行结果:
```
GradientBoostingClassifier Train Time: 846.05s
GradientBoostingClassifier Test Accuracy: 0.9691
HistGradientBoostingClassifier Train Time: 9.09s
HistGradientBoostingClassifier Test Accuracy: 0.9687
```
可以看到,使用HistGradientBoostingClassifier训练的时间远远少于GradientBoostingClassifier,而两者的测试精度相差不大。因此,在处理大型数据集时,使用HistGradientBoostingClassifier可以大大提高模型训练速度。
阅读全文