kmeans聚类算法解决mnist代码
时间: 2023-11-06 20:06:52 浏览: 137
Kmeans聚类算法可以用于解决MNIST数据集的问题。MNIST数据集是一个手写数字的图像数据集,其中包含了60000个训练样本和10000个测试样本。Kmeans聚类算法可以将这些样本聚类成k个簇,其中k是用户给定的。下面是Kmeans聚类算法解决MNIST数据集的代码:
```
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
digits = load_digits()
X = digits.data
y = digits.target
kmeans = KMeans(n_clusters=10, random_state=0)
clusters = kmeans.fit_predict(X)
for i in range(10):
print("Cluster ", i)
print(y[clusters == i])
```
在这个代码中,我们首先导入了KMeans算法和MNIST数据集。然后,我们将MNIST数据集中的图像数据存储在X中,将标签存储在y中。接着,我们创建了一个KMeans对象,将其聚类数设置为10,并使用fit_predict方法对数据进行聚类。最后,我们将每个簇中的标签打印出来。
相关问题
kmeans聚类算法案例实现mnist
### K-Means聚类算法实现MNIST手写数字识别
#### 数据准备与导入库
为了使用K-Means聚类算法处理MNIST数据集,首先需要安装并加载必要的Python包。这些工具可以帮助完成数据分析、机器学习建模以及图形展示的任务。
```python
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
```
#### 加载MNIST数据集
通过`fetch_openml`函数可以方便地下载MNIST数据集,并对其进行初步预处理以便后续操作。
```python
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist["data"], mnist["target"]
y = y.astype(np.int8) # 将标签转换成整数类型
```
#### 应用降维技术PCA
考虑到原始图片维度较高(每张图像是28×28像素),采用主成分分析(Principal Component Analysis, PCA)来降低特征空间的复杂度,从而提高计算效率和模型性能。
```python
pca = PCA(n_components=0.95) # 设置保留95%方差的比例
reduced_X = pca.fit_transform(X / 255.) # 归一化输入数据后再做PCA变换
print(f"Reduced dimensions to {reduced_X.shape[1]} from original 784.")
```
#### 构建与训练K-Means模型
创建一个具有指定数量簇(cluster)的K-Means实例对象,并调用`.fit()`方法执行实际的学习过程;这里假设已知类别数目为10个即代表十个阿拉伯数字字符。
```python
n_clusters = len(np.unique(y))
model = KMeans(n_clusters=n_clusters, random_state=42)
model.fit(reduced_X[:6000]) # 只选取部分样本来加快速度
```
#### 预测及评估效果
利用训练好的模型对测试集中剩余的数据点做出预测,进而统计各类别的分配情况及其准确性指标。
```python
predicted_labels = model.predict(reduced_X[6000:])
cm = confusion_matrix(y[6000:], predicted_labels)
def find_nearest_label(kmeans_model, actual_labels):
"""辅助函数用于匹配聚类中心到真实标签"""
centers = kmeans_model.cluster_centers_
closest, _ = pairwise_distances_argmin_min(centers, reduced_X[6000:])
mapping = dict(zip(range(len(closest)), [actual_labels[i] for i in closest]))
return mapping
label_map = find_nearest_label(model, y[6000:])
mapped_predictions = np.array([label_map[prediction] for prediction in predicted_labels])
accuracy = accuracy_score(y[6000:], mapped_predictions)
print(f'Accuracy after matching clusters with true labels: {accuracy:.3f}')
```
#### 结果可视化
最后一步是对最终得到的结果进行直观呈现,比如绘制混淆矩阵热力图或者散点分布图等。
```python
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(set(y)))
plt.xticks(tick_marks, set(y), rotation=45)
plt.yticks(tick_marks, set(y))
fmt = 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
# 绘制前两个主成分上的样本投影
plt.figure(figsize=(8, 6))
colors = ['navy', 'turquoise', 'darkorange']
lw = 2
for color, i, target_name in zip(colors, [0, 1, 2], ["Cluster 0", "Cluster 1", "Cluster 2"]):
plt.scatter(reduced_X[y == list(label_map.keys())[i], 0],
reduced_X[y == list(label_map.keys())[i], 1],
color=color,
alpha=.8,
lw=lw,
label=target_name)
plt.legend(loc='best', shadow=False, scatterpoints=1)
plt.title('PCA of MNIST Dataset')
plt.show()
```
kmeans聚类算法python 使用mnist数据
好的,我可以为您提供关于使用k-means算法对MNIST数据集进行聚类的Python代码。
首先,您需要导入所需的库:
```python
from sklearn.cluster import KMeans
from sklearn.datasets import fetch_openml
import numpy as np
import matplotlib.pyplot as plt
```
然后,您可以使用fetch_openml函数从OpenML获取MNIST数据集:
```python
mnist = fetch_openml('mnist_784')
X = mnist.data.astype('float32') / 255.0 # 将像素值缩放到0到1之间
y = mnist.target.astype('int64')
```
接下来,您可以使用KMeans算法对MNIST数据进行聚类:
```python
kmeans = KMeans(n_clusters=10, random_state=42)
kmeans.fit(X)
```
最后,您可以可视化聚类结果:
```python
fig, axes = plt.subplots(2, 5, figsize=(8, 3))
for i, ax in enumerate(axes.flat):
center = kmeans.cluster_centers_[i]
ax.imshow(center.reshape(28, 28), cmap='binary')
ax.set_title(f'Cluster {i}')
ax.axis('off')
plt.show()
```
这将显示10个聚类中心,每个聚类中心表示一组相似的数字图像。
希望这可以帮助您开始使用KMeans算法对MNIST数据进行聚类。
阅读全文
相关推荐
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)