python实现一个基于k-means算法的mnist手写数字图像分类器
时间: 2023-09-26 16:08:01 浏览: 154
以下是基于k-means算法的mnist手写数字图像分类器的 Python 实现:
首先,我们需要导入所需的库和数据集:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.cluster import KMeans
# 导入mnist数据集
mnist = fetch_openml('mnist_784')
X = mnist.data
y = mnist.target.astype(np.int)
```
然后,我们可以进行数据预处理,以便将像素值缩放到 [0, 1] 的范围内,并将图像大小调整为 28x28:
```python
# 数据预处理
X = X / 255.0
X = X.reshape(-1, 28, 28)
```
接下来,我们需要使用 k-means 算法对数据进行聚类。我们将每个图像视为一个样本,将其所有像素作为特征,并使用 k-means 算法将它们划分为 10 个簇(对应于 10 个数字类别):
```python
# 使用k-means算法将数据聚类为10个簇
X_flat = X.reshape(-1, 28*28)
kmeans = KMeans(n_clusters=10, random_state=42)
kmeans.fit(X_flat)
```
然后,我们可以使用 k-means 模型对测试数据进行预测:
```python
# 对测试数据进行预测
test_idx = np.random.randint(0, len(X), size=10)
test_images = X[test_idx].reshape(-1, 28*28)
pred_labels = kmeans.predict(test_images)
```
最后,我们可以可视化一些测试图像及其预测标签:
```python
# 可视化预测结果
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i, idx in enumerate(test_idx):
img = X[idx]
ax[i//5, i%5].imshow(img, cmap='gray')
ax[i//5, i%5].set_title('Predicted Label: {}'.format(pred_labels[i]))
plt.show()
```
完整的代码如下所示:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.cluster import KMeans
# 导入mnist数据集
mnist = fetch_openml('mnist_784')
X = mnist.data
y = mnist.target.astype(np.int)
# 数据预处理
X = X / 255.0
X = X.reshape(-1, 28, 28)
# 使用k-means算法将数据聚类为10个簇
X_flat = X.reshape(-1, 28*28)
kmeans = KMeans(n_clusters=10, random_state=42)
kmeans.fit(X_flat)
# 对测试数据进行预测
test_idx = np.random.randint(0, len(X), size=10)
test_images = X[test_idx].reshape(-1, 28*28)
pred_labels = kmeans.predict(test_images)
# 可视化预测结果
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i, idx in enumerate(test_idx):
img = X[idx]
ax[i//5, i%5].imshow(img, cmap='gray')
ax[i//5, i%5].set_title('Predicted Label: {}'.format(pred_labels[i]))
plt.show()
```
运行该代码,即可得到基于 k-means 算法的 mnist 手写数字图像分类器的预测结果。
阅读全文