写一下svm鸢尾花实验的代码
时间: 2024-12-24 18:20:36 浏览: 3
在Python中使用Scikit-learn库来做SVM鸢尾花分类的示例代码如下:
```python
# 导入必要的库
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 数据预处理:归一化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
# 创建并训练SVM模型(默认的是RBF核)
clf = SVC(kernel='rbf', C=1, gamma='scale') # 这里可以尝试不同的参数组合
clf.fit(X_train, y_train)
# 预测
y_pred = clf.predict(X_test)
# 评估性能
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")
# 可视化决策边界(仅适用于二分类问题)
if len(np.unique(y)) == 2:
from matplotlib.colors import ListedColormap
X_combined = np.vstack((X_train, X_test))
y_combined = np.hstack((y_train, y_test))
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])
plt.figure(figsize=(10, 8))
for idx, cl in enumerate(np.unique(y)):
plt.scatter(x=X_combined[y_combined == cl, 0], y=X_combined[y_combined == cl, 1],
alpha=0.8, c=cmap_bold(idx), edgecolor='black',
label=f'{cl}')
plt.title("SVM Decision Boundaries")
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend(loc='upper left')
plt.show()
```
阅读全文