GridSearchCV中使用混淆矩阵代码
时间: 2023-06-10 21:07:48 浏览: 137
GridSearchCV是一个用于超参数调优的工具,它可以帮助我们自动地搜索最佳超参数的组合。在使用GridSearchCV进行模型训练时,我们可以通过设置参数scoring来选择模型评估方法,其中可以选择使用混淆矩阵进行评估。
下面是一个使用GridSearchCV和混淆矩阵进行模型训练和评估的示例代码:
```python
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import load_digits
# 加载手写数字数据集
digits = load_digits()
# 定义SVC模型和超参数搜索范围
svc = SVC()
param_grid = {
'C': [0.1, 1, 10],
'kernel': ['linear', 'rbf']
}
# 定义混淆矩阵打印函数
def print_confusion_matrix(y_true, y_pred):
cm = confusion_matrix(y_true, y_pred)
print('Confusion matrix:\n{}'.format(cm))
# 使用GridSearchCV进行模型训练和超参数搜索
grid_search = GridSearchCV(svc, param_grid=param_grid, scoring='accuracy', cv=5)
grid_search.fit(digits.data, digits.target)
# 打印最佳超参数和最佳得分
print('Best parameters: {}'.format(grid_search.best_params_))
print('Best score: {:.2f}'.format(grid_search.best_score_))
# 计算测试集上的预测结果并打印混淆矩阵
y_pred = grid_search.predict(digits.data)
print_confusion_matrix(digits.target, y_pred)
```
在上述示例代码中,我们首先加载手写数字数据集,然后定义了一个SVC模型和超参数搜索范围。接着,我们定义了一个打印混淆矩阵的函数print_confusion_matrix。
然后,我们使用GridSearchCV进行模型训练和超参数搜索,其中设置了参数scoring为'accuracy'来进行模型评估。最后,我们使用训练好的模型对测试集进行预测,并调用print_confusion_matrix函数来打印测试集上的混淆矩阵。
阅读全文