svm对sklearn自带手写数字数据集进行分类
时间: 2024-11-20 07:41:30 浏览: 13
SVM (支持向量机) 是一种常用的监督学习算法,在sklearn库中可以方便地用于手写数字数据集MNIST的分类任务。MNIST是一个包含60,000个训练样本和10,000个测试样本的手写数字图像数据库。
在sklearn中,你可以按照以下步骤使用SVM对MNIST数据集进行分类:
1. 导入所需的库和模块:
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import classification_report, accuracy_score
```
2. 加载和预处理数据:
```python
# 加载数据
digits = datasets.load_digits()
X = digits.data
y = digits.target
# 将数据分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
```
3. 创建并训练SVM模型:
```python
svm_model = SVC(kernel='linear') # 或者尝试其他内核如 'poly', 'rbf'
svm_model.fit(X_train, y_train)
```
4. 进行预测并评估性能:
```python
y_pred = svm_model.predict(X_test)
print("Classification Report:")
print(classification_report(y_test, y_pred))
print(f"Accuracy Score: {accuracy_score(y_test, y_pred)}")
```
阅读全文