编写一个用Python实现支持向量机分类
时间: 2024-09-08 11:01:09 浏览: 49
支持向量机(SVM)是一种常见的监督学习方法,主要用于分类问题。SVM的核心思想是找到一个超平面,将不同类别的数据分割开来,同时使得分割的间隔(即支持向量与超平面的距离)最大化。在Python中,可以使用机器学习库scikit-learn来实现SVM分类。
以下是一个简单的Python代码示例,展示如何使用scikit-learn库中的SVM分类器对数据集进行分类:
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
# 加载数据集,以鸢尾花数据集为例
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 数据标准化,有助于提高SVM分类器的性能
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# 创建SVM分类器实例,这里使用默认的RBF核函数
svm_classifier = SVC(kernel='rbf', gamma='auto')
# 训练模型
svm_classifier.fit(X_train, y_train)
# 进行预测
predictions = svm_classifier.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, predictions)
print(f'模型准确率为: {accuracy:.2f}')
# 使用训练好的模型进行预测
# example_data = [[5.1, 3.5, 1.4, 0.2]]
# prediction = svm_classifier.predict(example_data)
# print(f'预测结果: {prediction[0]}')
```
上述代码包含了以下步骤:
1. 导入必要的库。
2. 加载并准备数据集,这里使用了鸢尾花(Iris)数据集。
3. 划分数据集为训练集和测试集。
4. 对数据进行标准化处理。
5. 创建SVM分类器实例并使用RBF核函数。
6. 训练模型。
7. 对测试集进行预测并计算准确率。
请注意,在实际应用中,可能需要对SVM分类器的参数进行调整,以获得更好的性能。例如,选择不同的核函数、调整`C`和`gamma`参数等。
阅读全文