SVM算法对非线性数据多分类并可视化多个决策边界的python代码
时间: 2023-12-03 18:42:10 浏览: 29
下面是使用SVM算法对非线性数据进行多分类,并可视化多个决策边界的Python代码。
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
# 生成非线性数据
X, y = datasets.make_classification(n_features=2, n_redundant=0, n_informative=2,
random_state=1, n_clusters_per_class=1, n_classes=3)
# 创建SVM分类器
clf = svm.SVC(kernel='rbf', C=1.0, gamma=0.1)
# 拟合数据
clf.fit(X, y)
# 绘制决策边界
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
np.arange(y_min, y_max, 0.02))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# 绘制分类边界和训练点
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
plt.show()
```
在这个例子中,我们使用了`make_classification`函数生成了一个包含三个类别的非线性数据集。然后我们创建了一个SVM分类器,并使用`fit`方法拟合数据。最后,我们使用`predict`方法预测网格中的点,并将预测结果可视化为决策边界。