SVM Hyperparameter Comparison: A Python Visualization Example
Date: 2023-11-05 20:01:57
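Below is a Python visualization example that compares SVM hyperparameters.
First, we import the required libraries and load the dataset. This example uses the Iris dataset, keeping only its first two features (sepal length and sepal width).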
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
# Load the iris dataset
iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
```
Next, we define a function that trains an SVM with an RBF kernel and returns its accuracy on the test set.
```python
def train_svm(C, gamma):
    # Create an SVM classifier
    svm = SVC(kernel='rbf', C=C, gamma=gamma)
    # Train the classifier on the training data
    svm.fit(X_train, y_train)
    # Make predictions on the testing data
    y_pred = svm.predict(X_test)
    # Calculate the accuracy of the classifier
    acc = accuracy_score(y_test, y_pred)
    return acc
```
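The function above scores each setting on the held-out test set, so picking hyperparameters from those scores can give an optimistic picture. As a hedged alternative sketch (not part of the original example), the comparison could instead use cross-validation on the training split. The helper name `cv_score_svm` and the choice of 5 folds are assumptions made for illustration.

```python
from sklearn import datasets
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.svm import SVC

# Same data preparation as in the example: first two Iris features, 70/30 split
iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

def cv_score_svm(C, gamma, folds=5):
    # Hypothetical helper: mean cross-validated accuracy on the training split only
    svm = SVC(kernel='rbf', C=C, gamma=gamma)
    scores = cross_val_score(svm, X_train, y_train, cv=folds)
    return scores.mean()

# Score one illustrative combination; the test set stays untouched
score = cv_score_svm(C=1.0, gamma=0.1)
print(f"mean CV accuracy: {score:.3f}")
```

With this variant the test set can be reserved for a single final check of the chosen combination.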
Now we can use this function to compare different hyperparameter combinations and visualize the results.
```python
# Define the range of values for C and gamma
C_range = np.logspace(-3, 3, 7)
gamma_range = np.logspace(-3, 3, 7)
# Create a meshgrid of C and gamma values
C, gamma = np.meshgrid(C_range, gamma_range)
# Initialize an array to store the accuracy scores
acc_scores = np.zeros((len(C_range), len(gamma_range)))
# Calculate the accuracy score for each combination of C and gamma
for i in range(len(C_range)):
    for j in range(len(gamma_range)):
        acc_scores[i, j] = train_svm(C_range[i], gamma_range[j])
# Plot the accuracy scores as a heatmap
plt.figure(figsize=(10, 8))
plt.imshow(acc_scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('gamma')
plt.ylabel('C')
plt.colorbar()
plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
plt.yticks(np.arange(len(C_range)), C_range)
plt.title('SVM accuracy')
plt.show()
```
This produces a heatmap showing the test accuracy for each combination of C and gamma. Rows correspond to C and columns to gamma.
![SVM hyperparameter comparison](https://i.imgur.com/7VdWZbQ.png)
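Reading the best cell off a heatmap by eye can be imprecise. The following is a minimal sketch, assuming the same data split and parameter grid as the example, that locates the highest-scoring combination numerically with `np.argmax` and `np.unravel_index`. The names `best_C` and `best_gamma` are introduced here for illustration only.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Same data preparation as in the example
iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

C_range = np.logspace(-3, 3, 7)
gamma_range = np.logspace(-3, 3, 7)

# Rows index C, columns index gamma, matching the heatmap axes
acc_scores = np.zeros((len(C_range), len(gamma_range)))
for i, C in enumerate(C_range):
    for j, gamma in enumerate(gamma_range):
        model = SVC(kernel='rbf', C=C, gamma=gamma).fit(X_train, y_train)
        acc_scores[i, j] = model.score(X_test, y_test)  # mean accuracy

# Convert the flat index of the maximum into (row, column) indices
best_i, best_j = np.unravel_index(np.argmax(acc_scores), acc_scores.shape)
best_C, best_gamma = C_range[best_i], gamma_range[best_j]
print(f"best C={best_C:g}, gamma={best_gamma:g}, accuracy={acc_scores[best_i, best_j]:.3f}")
```

If several cells tie for the maximum, `np.argmax` returns the first one in row-major order, so the reported pair is only one of the best combinations.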
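According to the original reading of the figure, accuracy is low when both C and gamma are small and highest when both are large. The original also notes that accuracy tends to be high where C and gamma are equal, although that cannot hold at the small end of the range given the first observation. Treat the claim about large values with care: a very large gamma makes the RBF kernel extremely narrow, which usually leads to overfitting, so the pattern should be confirmed against the heatmap that the code actually produces.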
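One way to check an interpretation like this is to compare training and test accuracy for a few settings: a large gap between the two suggests overfitting. The sketch below assumes the same data split as the example. The helper `train_and_test_accuracy` and the three gamma values (with C fixed at 1000) are illustrative choices, not part of the original.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Same data preparation as in the example
iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

def train_and_test_accuracy(C, gamma):
    # Hypothetical helper: fit once, then score on both splits
    model = SVC(kernel='rbf', C=C, gamma=gamma).fit(X_train, y_train)
    return model.score(X_train, y_train), model.score(X_test, y_test)

# Hold C fixed and vary gamma across the extremes of the grid
results = {}
for gamma in (0.001, 1.0, 1000.0):
    results[gamma] = train_and_test_accuracy(1000.0, gamma)
    train_acc, test_acc = results[gamma]
    print(f"gamma={gamma:g}: train={train_acc:.3f}, test={test_acc:.3f}")
```

A setting whose training accuracy is much higher than its test accuracy is fitting the training points rather than the underlying pattern.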