sklearn网格搜索
时间: 2023-09-02 14:14:00 浏览: 69
Scikit-learn中的网格搜索(GridSearchCV)是一种用于寻找最优超参数的方法,它对每个超参数的可能取值进行交叉验证,并返回最优参数的组合。
使用网格搜索需要以下步骤:
1. 创建参数字典:根据模型需要调整的超参数,创建一个包含所有可能取值的参数字典。
2. 实例化GridSearchCV对象:使用参数字典和模型对象创建GridSearchCV对象。
3. 进行网格搜索:使用fit()方法对数据进行拟合,GridSearchCV对象会自动尝试所有可能的参数组合,返回最佳超参数组合。
4. 输出结果:输出最佳超参数组合和对应的模型指标。
示例代码如下:
``` python
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import load_iris
# 创建参数字典
params = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}
# 加载数据集
iris = load_iris()
# 实例化SVM模型
svm = SVC()
# 实例化GridSearchCV对象
grid_search = GridSearchCV(svm, params, cv=5)
# 进行网格搜索
grid_search.fit(iris.data, iris.target)
# 输出结果
print("Best score: %0.3f" % grid_search.best_score_)
print("Best parameters set:")
best_parameters = grid_search.best_estimator_.get_params()
for param_name in sorted(params.keys()):
print("\t%s: %r" % (param_name, best_parameters[param_name]))
```
其中,参数cv表示交叉验证的折数,可以根据需要调整。
阅读全文