用sklearn 自带digits数据集通过svm训练,并拆分数据集
时间: 2024-11-13 17:23:21 浏览: 49
在Python中,你可以使用scikit-learn (sklearn)库中的`digits`数据集来进行支持向量机(SVM)的训练。首先,你需要导入必要的模块,加载数据,然后将数据集拆分为训练集和测试集。以下是基本步骤:
1. 导入所需的库:
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
import numpy as np
```
2. 加载 digits 数据集:
```python
digits = datasets.load_digits()
```
这将返回一个包含数字图像、目标值和其他元数据的 `Bunch` 对象。
3. 准备数据(通常包括特征缩放):
```python
# 提取特征和标签
X = digits.data
y = digits.target
# 对数据进行标准化处理,提高SVM性能
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
```
4. 拆分数据集为训练集和测试集:
```python
# 分割比例可以根据需要调整,默认通常是70%训练,30%测试
X_train, X_test, y_train, y_test = train_test_split(
X_scaled, y, test_size=0.3, random_state=42)
```
这里的`random_state`是一个固定数值,用于保证每次结果的一致性,你可以根据实际情况更改。
5. 使用 SVM 训练模型:
```python
svm_model = SVC(kernel='linear') # 或者 'rbf', 'poly'等其他内核
svm_model.fit(X_train, y_train)
```
6. 预测并评估性能:
```python
predictions = svm_model.predict(X_test)
accuracy = np.mean(predictions == y_test)
```
阅读全文