如何用Python通过支持向量机(SVM)构建并可视化一个手写数字识别分类器,包括预处理阶段的数据可视化以及训练后评估的混淆矩阵展示?
时间: 2024-12-23 18:27:01 浏览: 11
基于SVM的手写数字分类 (python实现)代码+数据.rar
使用Python通过支持向量机(SVM)构建手写数字识别分类器通常涉及以下步骤:
1. **数据加载与预处理**:
- 导入必要的库,如`numpy`, `matplotlib`, `scikit-learn`(包括`svm`模块)等。
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
```
- 加载MNIST手写数字数据集:
```python
digits = load_digits()
X = digits.data
y = digits.target
```
- 对数据进行预处理,通常会对输入特征进行标准化(归一化):
```python
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
```
2. **数据分割与可视化**:
- 划分训练集和测试集:
```python
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
```
- 可视化原始图像示例(例如前5张):
```python
fig, axs = plt.subplots(1, 5, figsize=(10, 2))
for i, ax in enumerate(axs):
ax.imshow(X[i].reshape((8, 8)), cmap='gray')
ax.set_title(f"Digit {y[i]}")
plt.show()
```
3. **模型训练**:
- 创建SVM分类器,并选择合适的核函数(如线性、多项式或径向基函数RBF):
```python
from sklearn.svm import SVC
classifier = SVC(kernel='rbf', C=1)
classifier.fit(X_train, y_train)
```
4. **预测与评估**:
- 预测测试集结果:
```python
y_pred = classifier.predict(X_test)
```
- 计算准确率:
```python
accuracy = classifier.score(X_test, y_test)
print(f"Accuracy: {accuracy * 100:.2f}%")
```
5. **混淆矩阵**:
- 生成混淆矩阵并可视化:
```python
from sklearn.metrics import confusion_matrix, plot_confusion_matrix
cm = confusion_matrix(y_test, y_pred)
df_cm = pd.DataFrame(cm, index=digits.target_names, columns=digits.target_names)
plt.figure(figsize=(6, 6))
sns.heatmap(df_cm, annot=True, fmt="d")
plt.title("Confusion Matrix")
plt.show()
```
阅读全文