使用【python】写【svm】的代码
时间: 2024-10-04 21:01:57 浏览: 27
当然,这里有一个使用Python的Scikit-learn库编写支持向量机(SVM)的基本示例。我们将继续使用iris数据集来说明:
```python
# 导入所需的库
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import accuracy_score
# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建SVM分类器,选择线性内核
svm_model = svm.SVC(kernel='linear')
# 使用训练数据拟合模型
svm_model.fit(X_train, y_train)
# 进行预测
y_pred = svm_model.predict(X_test)
# 计算并打印预测精度
accuracy = accuracy_score(y_test, y_pred)
print(f"SVM模型的预测精度: {accuracy * 100:.2f}%")
# 可选:可视化决策边界
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])
fig, ax = plt.subplots()
ax.set_title('Support Vector Machine (Linear Kernel)')
ax.set_xlabel('Sepal length')
ax.set_ylabel('Sepal width')
for i, color in zip(range(len(iris.target_names)), cmap_light):
indices = np.where(y == i)
ax.scatter(X[indices, 0], X[indices, 1], c=color, label=iris.target_names[i],
alpha=.8)
ax.legend(loc='upper left')
plt.tight_layout()
plt.show()
```
这个例子展示了如何加载数据、划分训练集和测试集,然后创建并训练一个线性SVM模型,最后进行预测和评估。
阅读全文