Classifying scikit-learn's built-in handwritten digits dataset with an SVM
Date: 2024-11-20 20:39:12
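An SVM (support vector machine) is a powerful supervised learning algorithm commonly used for binary and multiclass classification. In scikit-learn, `sklearn.datasets.load_digits()` loads the preprocessed handwritten digits dataset (1,797 grayscale images of 8×8 pixels, flattened to 64 features, labeled 0 to 9), and the `SVC` (Support Vector Classifier) class can then be used to classify it.
First, import the required libraries and modules: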
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
```
Next, load the data and split it into a training set and a test set:
```python
digits = datasets.load_digits()
X = digits.data
y = digits.target
# Split the data: 80% for training, 20% for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
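Before training, it can help to check what the split actually contains. The sketch below is illustrative only: it reloads the data so it runs on its own, and the `stratify=y` argument is an addition not present in the original code, used here on the assumption that keeping class proportions equal across both sets is desirable.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
X, y = digits.data, digits.target

# Each sample is an 8x8 image flattened into 64 pixel intensities.
print(X.shape)              # (1797, 64)
print(digits.images.shape)  # (1797, 8, 8)
print(X.min(), X.max())     # 0.0 16.0

# stratify=y is an assumption added for this sketch, not part of the original split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print(X_train.shape, X_test.shape)  # (1437, 64) (360, 64)
```

With 1,797 samples and `test_size=0.2`, the test set is rounded up to 360 samples and the remaining 1,437 go to training.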
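SVMs are sensitive to feature scale, so numeric features are usually standardized to zero mean and unit variance. The scaler is fitted on the training set only and then reused on the test set, so no test-set statistics leak into training: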
```python
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
```
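One possible alternative to scaling by hand is to wrap the scaler and the classifier in a single pipeline, so the scaler is refitted inside every cross-validation fold. This is a minimal sketch rather than the original author's approach; `make_pipeline`, `cross_val_score`, and the choice of 5 folds are additions made for illustration.

```python
from sklearn import datasets
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = datasets.load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# The pipeline fits StandardScaler on the training part of each fold only.
pipe = make_pipeline(StandardScaler(), SVC(kernel="rbf", C=1))

# 5-fold cross-validation on the training set (fold count is an arbitrary choice).
scores = cross_val_score(pipe, X_train, y_train, cv=5)
print(f"Mean CV accuracy: {scores.mean():.4f}")
```

Because scaling happens inside the pipeline, calling `pipe.fit(X_train, y_train)` followed by `pipe.score(X_test, y_test)` applies the same transformation to the test data automatically.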
Then create the SVM model and train it:
```python
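# Pass a single kernel string; the expression 'linear' or 'rbf' would always evaluate to 'linear'.
svm_model = SVC(kernel='rbf', C=1)  # choose the kernel and regularization parameter C as needed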
svm_model.fit(X_train, y_train)
```
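If it is unclear which kernel or value of `C` suits the data, a small grid search is one way to compare them. The grid below (two kernels, three values of `C`, 3-fold cross-validation) is an arbitrary choice for illustration, and the block reloads the data so it can run on its own.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = datasets.load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# make_pipeline names the SVC step "svc", hence the "svc__" parameter prefix.
pipe = make_pipeline(StandardScaler(), SVC())
param_grid = {
    "svc__kernel": ["linear", "rbf"],
    "svc__C": [0.1, 1, 10],
}

search = GridSearchCV(pipe, param_grid, cv=3)
search.fit(X_train, y_train)

print("Best parameters:", search.best_params_)
print(f"Best CV accuracy: {search.best_score_:.4f}")
print(f"Test accuracy: {search.score(X_test, y_test):.4f}")
```

`GridSearchCV` refits the best combination on the full training set by default, so `search` can be used directly for prediction afterwards.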
Finally, use the trained model to make predictions and compute the accuracy:
```python
y_pred = svm_model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")
```
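A single accuracy number hides which digits get confused with each other. As a possible follow-up, the sketch below prints a confusion matrix and a per-class report; it repeats the earlier steps so it is self-contained, and the use of `confusion_matrix` and `classification_report` is an addition for illustration.

```python
from sklearn import datasets
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = datasets.load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Standardize using statistics from the training set only.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

svm_model = SVC(kernel="rbf", C=1)
svm_model.fit(X_train, y_train)
y_pred = svm_model.predict(X_test)

# Rows are true digits, columns are predicted digits.
cm = confusion_matrix(y_test, y_pred)
print(cm)

# Precision, recall, and F1 score for each digit class.
print(classification_report(y_test, y_pred))
```

Off-diagonal entries in the matrix show which pairs of digits the model mixes up.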