用python画出三维svm超平面
时间: 2024-10-27 07:04:39 浏览: 21
SVM基本概念及Python实现代码
在Python中,我们可以使用`matplotlib`库结合`mpl_toolkits.mplot3d`模块来绘制三维SVM超平面。首先,你需要安装`numpy`, `scikit-learn`以及`matplotlib`等库。下面是一个简单的示例,展示如何使用`sklearn`库中的`SVC`模型(支持向量机)和数据可视化:
```python
import numpy as np
from sklearn import svm, datasets
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 加载数据集(这里假设使用的是鸢尾花数据)
iris = datasets.load_iris()
X = iris.data[:, :3] # 取前三个维度作为特征
y = iris.target
# 创建支持向量机分类器
clf = svm.SVC(kernel='linear') # 线性核函数
# 训练模型
clf.fit(X, y)
# 获取决策边界参数
coef = clf.coef_[0] # 权重向量
intercept = clf.intercept_ # 截距
# 定义网格
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 50),
np.linspace(y_min, y_max, 50))
# 计算在网格点处的支持向量机预测值
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# 创建3D图形并画出决策边界
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_zlabel('Feature 3')
ax.plot_surface(xx, yy, Z, rstride=1, cstride=1, alpha=0.6)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y, cmap=plt.cmviridis)
plt.show()
阅读全文