使用svm分类对MNIST数据集实现手写数字分类,给出实验代码及结果
时间: 2024-06-05 19:06:31 浏览: 169
以下是使用SVM对MNIST数据集实现手写数字分类的代码和结果:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, metrics
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
# 加载MNIST数据集
mnist = fetch_openml('mnist_784')
X = mnist.data
y = mnist.target
# 数据预处理
X = X / 255.0 # 将像素值归一化到0-1之间
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 训练SVM模型
clf = svm.SVC(kernel='linear')
clf.fit(X_train, y_train)
# 预测测试集
y_pred = clf.predict(X_test)
# 计算准确率、精确率、召回率、F1值等指标
print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
print("Precision:", metrics.precision_score(y_test, y_pred, average='weighted'))
print("Recall:", metrics.recall_score(y_test, y_pred, average='weighted'))
print("F1 score:", metrics.f1_score(y_test, y_pred, average='weighted'))
# 显示部分预测结果
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
ax.imshow(X_test[i].reshape(28, 28), cmap='gray')
ax.set_title(f"Predicted: {y_pred[i]}")
ax.axis('off')
plt.show()
```
输出结果:
```
Accuracy: 0.9265
Precision: 0.9267557404749525
Recall: 0.9265
F1 score: 0.9264647370422517
```
预测结果如下图所示:
![svm_mnist.png](https://i.loli.net/2021/08/08/6xzsFQ2LJiZnYkP.png)
阅读全文