svm二分类python可视化
时间: 2023-09-30 19:09:33 浏览: 264
要在Python中可视化SVM二分类,可以使用Scikit-learn库中的SVC类和Matplotlib库来实现。下面是一个示例代码来实现这个过程:
首先,确保你已经导入所需的库,包括sklearn.svm.SVC、matplotlib.pyplot和numpy。然后,你需要准备你的数据集。在这个例子中,我们使用的是月亮数据集,可以使用make_moons函数从sklearn.datasets中生成。
接下来,你可以使用Pipeline类和StandardScaler来进行数据预处理和特征缩放。然后,创建一个SVC对象,并使用rbf核函数来实例化它。设置好超参数gamma和C的值。
然后,使用fit方法拟合SVC模型,并将其添加到svm_clfs列表中。
最后,使用matplotlib.pyplot的subplot和plot方法来绘制图形。在每个子图中,使用plot_predictions函数绘制决策边界和数据点,使用plot_dataset函数绘制数据集。为每个子图设置标题,显示超参数的值。
以下是示例代码:
```
import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
import numpy as np
import matplotlib as mpl
from sklearn.datasets import make_moons
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
# 为了显示中文
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False
X, y = make_moons(n_samples=100, noise=0.15, random_state=42)
def plot_dataset(X, y, axes):
plt.plot(X[:, 0][y==0], X[:, 1][y==0], "bs")
plt.plot(X[:, 0][y==1], X[:, 1][y==1], "g^")
plt.axis(axes)
plt.grid(True, which='both')
plt.xlabel(r"$x_1$", fontsize=20)
plt.ylabel(r"$x_2$", fontsize=20, rotation=0)
plt.title("月亮数据",fontsize=20)
hyperparams = ((gamma1, C1), (gamma1, C2))
svm_clfs = []
for gamma, C in hyperparams:
rbf_kernel_svm_clf = Pipeline([
("scaler", StandardScaler()),
("svm_clf", SVC(kernel="rbf", gamma=gamma, C=C))
])
rbf_kernel_svm_clf.fit(X, y)
svm_clfs.append(rbf_kernel_svm_clf)
plt.figure(figsize=(11, 7))
for i, svm_clf in enumerate(svm_clfs):
plt.subplot(221 + i)
plot_predictions(svm_clf, [-1.5, 2.5, -1, 1.5])
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
gamma, C = hyperparams[i]
plt.title(r"$\gamma = {}, C = {}$".format(gamma, C), fontsize=16)
plt.tight_layout()
plt.show()
```
这段代码将生成一个包含四个子图的图形,每个子图显示了不同超参数设置下的SVM二分类器的决策边界和数据点分布。
阅读全文