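Using Python, classify the wine dataset bundled with sklearn with an SVM model: determine the best parameters via five-fold cross-validation and grid search, then compute accuracy, precision, F1 score, and recall; finally, plot the ROC curve of each of the three classes and the overall macro-average ROC curve on the same figure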
Time: 2024-11-12 17:37:32 Views: 30
In Python, this task can be done with the scikit-learn library. The steps are as follows:
1. Import the required modules:
```python
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
# plot_roc_curve was removed in scikit-learn 1.2 and never supported multiclass targets, so it is not imported
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score, roc_curve, auc
import matplotlib.pyplot as plt
```
2. Load the wine dataset and preprocess it:
```python
wine = datasets.load_wine()
X = wine.data
y = wine.target
# Standardize the features (zero mean, unit variance)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
```
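As a quick sanity check on this step, the sketch below prints the dataset's shape and class counts and confirms what `StandardScaler` does to each feature column. It is only an illustration of the preprocessing above, not part of the original answer.

```python
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler

wine = datasets.load_wine()
X, y = wine.data, wine.target

print(X.shape)            # (n_samples, n_features)
print(wine.target_names)  # names of the three classes
print(np.bincount(y))     # number of samples in each class

X_scaled = StandardScaler().fit_transform(X)
# After standardization every feature column has mean ~0 and standard deviation ~1
print(np.allclose(X_scaled.mean(axis=0), 0.0))
print(np.allclose(X_scaled.std(axis=0), 1.0))
```

The classes are moderately imbalanced, which is one reason to report averaged precision, recall, and F1 alongside accuracy.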
3. Split the data into training and test sets (for example, 70% training and 30% testing):
```python
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)
```
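Because the three classes have different sizes, a stratified split can keep the class proportions similar in both subsets. The `stratify=y` argument below is an optional variation, not something the original answer uses; the sketch simply compares the class proportions before and after splitting.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

X, y = datasets.load_wine(return_X_y=True)

# stratify=y asks train_test_split to preserve the class proportions in both subsets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

full_ratio = np.bincount(y) / len(y)
train_ratio = np.bincount(y_train) / len(y_train)
test_ratio = np.bincount(y_test) / len(y_test)
print(full_ratio, train_ratio, test_ratio)
```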
4. Use GridSearchCV (with five-fold cross-validation) to determine the best SVM parameters:
```python
param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf'], 'gamma': ['scale', 'auto']}
svc = SVC()
grid_search = GridSearchCV(svc, param_grid, cv=5)
grid_search.fit(X_train, y_train)
best_params = grid_search.best_params_
print(f"Best parameters: {best_params}")
```
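One caveat with the steps above: the scaler is fitted on the full dataset before splitting, so statistics from the test set and from each validation fold leak into training. A possible alternative, sketched here as an assumption rather than the original author's method, is to wrap the scaler and the SVM in a `Pipeline` so that scaling is refitted inside every cross-validation fold. The step names `"scaler"` and `"svc"` are arbitrary labels chosen for this example.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = datasets.load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# The scaler is fitted only on the training part of each fold
pipe = Pipeline([("scaler", StandardScaler()), ("svc", SVC())])
param_grid = {
    "svc__C": [0.1, 1, 10],
    "svc__kernel": ["linear", "rbf"],
    "svc__gamma": ["scale", "auto"],
}
grid_search = GridSearchCV(pipe, param_grid, cv=5, scoring="accuracy")
grid_search.fit(X_train, y_train)

print(f"Best parameters: {grid_search.best_params_}")
print(f"Mean CV accuracy of best model: {grid_search.best_score_:.3f}")
test_accuracy = grid_search.score(X_test, y_test)
print(f"Test accuracy: {test_accuracy:.3f}")
```

With the default `refit=True`, `grid_search` is already refitted on the whole training set using the best parameters, so it can be used directly for prediction.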
5. Train the model and evaluate its performance:
```python
model = SVC(**best_params)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average='weighted')
recall = recall_score(y_test, y_pred, average='weighted')
f1 = f1_score(y_test, y_pred, average='weighted')
print(f"Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, F1 Score: {f1:.2f}, Recall: {recall:.2f}")
# classification_report can be imported from sklearn.metrics for a more detailed per-class report
```
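For the detailed per-class report mentioned in the comment above, `classification_report` and `confusion_matrix` from `sklearn.metrics` can be used. The sketch below is self-contained and uses fixed hyperparameters (`C=1`, `kernel="rbf"`) purely as an assumption for illustration; in practice the grid-search result would be substituted.

```python
from sklearn import datasets
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

wine = datasets.load_wine()
X_scaled = StandardScaler().fit_transform(wine.data)
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, wine.target, test_size=0.3, random_state=42
)

# Assumed hyperparameters for illustration; substitute the grid-search result
model = SVC(C=1, kernel="rbf", gamma="scale")
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Per-class precision, recall, F1, and support, plus macro and weighted averages
report = classification_report(y_test, y_pred, target_names=wine.target_names)
print(report)

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_test, y_pred)
print(cm)
```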
6. Compute and plot the ROC curves:
```python
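from sklearn.preprocessing import label_binarize

# The SVC above was fitted without probability=True, so predict_proba is unavailable;
# decision_function returns one one-vs-rest score per class, shape (n_samples, 3)
y_scores = model.decision_function(X_test)
y_test_bin = label_binarize(y_test, classes=[0, 1, 2])  # one indicator column per class

fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(3):  # the three classes of the wine dataset
    fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_scores[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Macro-average: interpolate each class's TPR onto a common FPR grid, then average
mean_fpr = np.linspace(0, 1, 100)
all_tpr = np.zeros_like(mean_fpr)
for i in range(3):
    all_tpr += np.interp(mean_fpr, fpr[i], tpr[i])
mean_tpr = all_tpr / 3.
roc_auc['macro'] = auc(mean_fpr, mean_tpr)

# Plot the per-class curves and the macro-average curve on the same figure
for i in range(3):
    plt.plot(fpr[i], tpr[i], lw=1.5,
             label=f'Class {i} ROC curve (area = {roc_auc[i]:.2f})')
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
         label='Chance', alpha=.8)
plt.plot(mean_fpr, mean_tpr, label=f'Macro-average ROC curve (area = {roc_auc["macro"]:.2f})',
         color='navy', linestyle=':', linewidth=2, alpha=.8)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc="lower right")
plt.show()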
```
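This displays the ROC curve of each of the three classes together with the overall macro-average ROC curve, along with the corresponding AUC scores.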
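As a numerical cross-check on the AUC values, `roc_auc_score` can compute the per-class and macro-averaged one-vs-rest AUC directly from binarized labels and decision scores. The sketch below is self-contained and again assumes fixed hyperparameters (`C=1`, `kernel="rbf"`) for illustration only.

```python
import numpy as np
from sklearn import datasets
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.svm import SVC

wine = datasets.load_wine()
X_scaled = StandardScaler().fit_transform(wine.data)
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, wine.target, test_size=0.3, random_state=42
)

# Assumed hyperparameters for illustration; substitute the grid-search result
model = SVC(C=1, kernel="rbf", gamma="scale")
model.fit(X_train, y_train)

scores = model.decision_function(X_test)                 # shape (n_samples, 3)
y_test_bin = label_binarize(y_test, classes=[0, 1, 2])   # indicator matrix

# With an indicator matrix, roc_auc_score treats each column as a binary problem
per_class_auc = roc_auc_score(y_test_bin, scores, average=None)
macro_auc = roc_auc_score(y_test_bin, scores, average="macro")
print(per_class_auc)
print(f"Macro-average AUC: {macro_auc:.3f}")
```

Here the macro value is the plain mean of the three per-class AUCs, so it can differ slightly from the area under a macro curve built by interpolating onto a fixed FPR grid.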