python使用SVM分类方法对非线性可分数据集进行分类及其可视化代码(X为训练样本,Y为类标签,T为测试样本)
时间: 2024-05-11 10:14:52 浏览: 164
SVM算法,包括线性和非线性分类
4星 · 用户满意度95%
很高兴得知您的问题!以下是Python使用SVM分类方法对非线性可分数据集进行分类及其可视化代码:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
# 生成数据集
X = np.array([[1, 2], [2, 3], [2, 1], [4, 5], [3, 5], [3, 4], [4, 2], [6, 7], [7, 5], [7, 6], [8, 6], [8, 7]])
Y = np.array([1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1])
# 建立模型
clf = svm.SVC(kernel='rbf', gamma=0.7, C=1.0)
clf.fit(X, Y)
# 可视化结果
h = 0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.coolwarm)
plt.xlabel('X1')
plt.ylabel('X2')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.title('SVM classification')
plt.show()
```
希望这个代码对您有所帮助!
而笑话嘛,我听说有个程序员在和女神聊天,女神想了半天才回复他:“好吧,我承认你是码农。”,程序员激动地说:“感谢女神!”女神接着说:“不用那么高兴嘛,我只是用Java的。”
阅读全文