写一段SVM通过高斯基核函数进行二分类的python代码,计算识别率并将结果可视化
时间: 2024-05-22 16:12:04 浏览: 93
以下是SVM通过高斯基核函数进行二分类的python代码,使用sklearn库中的SVC模型:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, plot_confusion_matrix
# 加载数据集
iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target
# 将标签为0和1的样本挑出来
X = X[y != 2]
y = y[y != 2]
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
# 构建SVM模型
clf = svm.SVC(kernel='rbf', gamma=0.7)
clf.fit(X_train, y_train)
# 计算预测准确率
accuracy = clf.score(X_test, y_test)
print("Accuracy:", accuracy)
# 绘制决策边界
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.title('SVM with RBF Kernel')
plt.show()
```
代码中使用了鸢尾花数据集的前两个特征,将标签为0和1的样本挑出来作为二分类任务的数据。使用SVC模型,指定核函数为高斯基核函数,并设置gamma参数为0.7。计算预测准确率并将决策边界可视化。运行代码后,可以得到如下图所示的结果:
![svm_rbf_kernel](https://img-blog.csdnimg.cn/20210622172823529.png)
阅读全文