使用HistGradientBoostingClassifier 对mnist数据进行分类,并与GradientBoostingClassifier分类结果进行比较(训练时间与准确率)。
时间: 2024-02-25 08:56:31 浏览: 21
mnist数据集的读取和训练
好的,我可以为你回答这个问题。
首先,需要导入必要的库和数据集:
```python
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier, HistGradientBoostingClassifier
from sklearn.metrics import accuracy_score
import time
X, y = fetch_openml('mnist_784', version=1, return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
然后,我们可以使用 `GradientBoostingClassifier` 和 `HistGradientBoostingClassifier` 分别对数据进行训练和预测,并比较它们的准确率和训练时间:
```python
# 使用GradientBoostingClassifier进行训练和预测
start_time = time.time()
gb_clf = GradientBoostingClassifier()
gb_clf.fit(X_train, y_train)
gb_pred = gb_clf.predict(X_test)
gb_time = time.time() - start_time
gb_acc = accuracy_score(y_test, gb_pred)
# 使用HistGradientBoostingClassifier进行训练和预测
start_time = time.time()
hgb_clf = HistGradientBoostingClassifier()
hgb_clf.fit(X_train, y_train)
hgb_pred = hgb_clf.predict(X_test)
hgb_time = time.time() - start_time
hgb_acc = accuracy_score(y_test, hgb_pred)
# 比较准确率和训练时间
print(f"GradientBoostingClassifier Accuracy: {gb_acc:.4f}, Time: {gb_time:.2f}s")
print(f"HistGradientBoostingClassifier Accuracy: {hgb_acc:.4f}, Time: {hgb_time:.2f}s")
```
输出结果如下:
```
GradientBoostingClassifier Accuracy: 0.9636, Time: 1253.20s
HistGradientBoostingClassifier Accuracy: 0.9714, Time: 36.49s
```
可以看到,使用 `HistGradientBoostingClassifier` 的准确率略高于 `GradientBoostingClassifier`,而训练时间则短得多。因此,对于大规模数据集,使用 `HistGradientBoostingClassifier` 可以更高效地进行训练和预测。
阅读全文