采用不同的SVM 核函数对多种类型数据集进行二分类,使用 scikit-learn 中提供的样本生成器 make_blobs、make_classification、make_moons、 make_circles 生成一系列线性或非线性可分的二类别数据;分别将 SVM 中四种核函数(线性核、多项式核、高斯核、S 形核)用于上述四种数据集;通过散点图可视化数据样本,并画出 SVM 模型的决策边界,通过模型评价分类准确率。
时间: 2023-12-20 15:04:20 浏览: 83
使用SVM做分类
好的,下面是针对您的问题,使用 scikit-learn 中提供的样本生成器 make_blobs、make_classification、make_moons、 make_circles 生成一系列线性或非线性可分的二类别数据,并分别将 SVM 中四种核函数(线性核、多项式核、高斯核、S 形核)用于上述四种数据集;通过散点图可视化数据样本,并画出 SVM 模型的决策边界,通过模型评价分类准确率的代码实现:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs, make_classification, make_moons, make_circles
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
# 定义 SVM 模型和核函数
model_linear = SVC(kernel='linear')
model_poly = SVC(kernel='poly', degree=2)
model_rbf = SVC(kernel='rbf')
model_sigmoid = SVC(kernel='sigmoid')
# 定义数据集生成器和相关参数
dataset_funcs = [make_blobs, make_classification, make_moons, make_circles]
kernel_funcs = [model_linear, model_poly, model_rbf, model_sigmoid]
dataset_names = ['make_blobs', 'make_classification', 'make_moons', 'make_circles']
# 循环遍历数据集,分别训练 SVM 模型并绘制决策边界
for i, dataset_func in enumerate(dataset_funcs):
X, y = dataset_func(n_samples=200, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
plt.figure(figsize=(10, 8))
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='viridis', alpha=0.5)
# 循环遍历核函数,分别训练 SVM 模型并绘制决策边界
for j, kernel_func in enumerate(kernel_funcs):
model_name = kernel_func.kernel
plt.subplot(2, 2, j+1)
plt.title(model_name)
kernel_func.fit(X_train, y_train)
y_pred = kernel_func.predict(X_test)
acc = accuracy_score(y_test, y_pred)
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='viridis', alpha=0.5)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred, cmap='viridis', marker='x')
xx, yy = np.meshgrid(np.linspace(X[:, 0].min()-0.5, X[:, 0].max()+0.5, 100),
np.linspace(X[:, 1].min()-0.5, X[:, 1].max()+0.5, 100))
Z = kernel_func.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contour(xx, yy, Z, colors='k', levels=[-1, 0, 1], alpha=0.5)
plt.text(xx.max()-0.3, yy.min()+0.3, f"Accuracy: {acc:.2f}")
plt.suptitle(f"{dataset_names[i]} dataset")
plt.show()
```
这段代码首先定义了 SVM 模型和核函数,并定义了数据集生成器和相关参数。然后循环遍历数据集和核函数,分别训练 SVM 模型并绘制决策边界。在每个数据集中,使用散点图可视化数据样本,并循环遍历核函数,分别训练 SVM 模型并绘制决策边界。最后输出模型的分类准确率。
阅读全文