Implement K-means in Python to cluster the handwritten digits of the MNIST dataset, analyze the results, and compare them with the IsoData method
Posted: 2023-06-11 11:08:46
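Sure. Below is a walkthrough of how to cluster the MNIST handwritten digits with K-means in Python, evaluate the result, and compare it with IsoData.
1. Preparing the dataset
First, download the MNIST handwritten digit dataset. The following code saves the four archives to a local `./data` directory: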
```python
import os
import urllib.request

url_train_images = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
url_train_labels = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
url_test_images = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
url_test_labels = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"

# Create the download directory if it does not exist yet.
os.makedirs("./data", exist_ok=True)

# Download the four gzip archives (training/test images and labels).
urllib.request.urlretrieve(url_train_images, "./data/train-images-idx3-ubyte.gz")
urllib.request.urlretrieve(url_train_labels, "./data/train-labels-idx1-ubyte.gz")
urllib.request.urlretrieve(url_test_images, "./data/t10k-images-idx3-ubyte.gz")
urllib.request.urlretrieve(url_test_labels, "./data/t10k-labels-idx1-ubyte.gz")
```
Once the download finishes, decompress the archives and read their contents. The following code does this:
```python
import gzip
import numpy as np

def read_data(filename):
    """Read an IDX image file and return an (n_images, 784) uint8 array."""
    with gzip.open(filename, 'rb') as f:
        # Skip the 16-byte header (magic number, image count, rows, columns).
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    return data.reshape(-1, 784)

def read_labels(filename):
    """Read an IDX label file and return a 1-D uint8 array of digits."""
    with gzip.open(filename, 'rb') as f:
        # Skip the 8-byte header (magic number, label count).
        data = np.frombuffer(f.read(), np.uint8, offset=8)
    return data

train_images = read_data("./data/train-images-idx3-ubyte.gz")
train_labels = read_labels("./data/train-labels-idx1-ubyte.gz")
test_images = read_data("./data/t10k-images-idx3-ubyte.gz")
test_labels = read_labels("./data/t10k-labels-idx1-ubyte.gz")
```
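Here NumPy handles the raw bytes: `read_data` reads the images, flattening each 28×28 image into a 784-element row, and `read_labels` reads the labels.
2. Running K-means
Next, run K-means on the training images. The following code does this: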
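The offsets 16 and 8 used above skip the file headers. As a sanity check, the header can be parsed rather than skipped. The sketch below assumes the usual IDX layout (two zero bytes, a data-type byte, a dimension-count byte, then one big-endian 32-bit size per dimension); `parse_idx_header` is a hypothetical helper name, not part of any library.

```python
import struct

def parse_idx_header(raw):
    """Return (dtype_code, shape, data_offset) for the bytes of an IDX file."""
    # The magic number is four bytes: 0, 0, data-type code, number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0:
        raise ValueError("not an IDX file: the first two bytes must be zero")
    # One big-endian unsigned 32-bit integer per dimension follows.
    shape = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    return dtype_code, shape, 4 + 4 * ndim

# Synthetic header shaped like the MNIST training image file (magic 2051).
header = struct.pack(">IIII", 2051, 60000, 28, 28)
print(parse_idx_header(header))  # (8, (60000, 28, 28), 16)
```

The returned offset could be passed to `np.frombuffer` in place of the hard-coded 16 or 8, and the shape could be checked against the reshaped array.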
```python
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=10, random_state=42)
kmeans.fit(train_images)
```
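Here Scikit-Learn's `KMeans` class does the clustering. The number of clusters is set to 10 to match the ten digit classes 0 to 9, and the model is fitted on the training images. K-means never sees the labels, so a cluster index does not by itself correspond to a particular digit.
3. Evaluating the results
To see how K-means performs on MNIST, the following code scores the clustering against the true labels: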
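For readers who want to see what the library call does internally, here is a minimal NumPy sketch of Lloyd's algorithm, the standard K-means iteration. It is an illustration under simple assumptions (random initial centers drawn from the data, a fixed iteration cap), not Scikit-Learn's implementation; `kmeans_numpy` and its parameters are hypothetical names.

```python
import numpy as np

def nearest_center(X, centers):
    """Index of the closest center for every row of X (squared Euclidean)."""
    # |x - c|^2 = |x|^2 - 2 x.c + |c|^2, which avoids building an (n, k, d) array.
    d2 = (X ** 2).sum(1)[:, None] - 2 * X @ centers.T + (centers ** 2).sum(1)[None, :]
    return d2.argmin(axis=1)

def kmeans_numpy(X, k, n_iter=100, seed=0):
    """Lloyd's algorithm: alternate assignment and center update."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=np.float64)
    # Start from k distinct data points chosen at random.
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        labels = nearest_center(X, centers)
        new_centers = centers.copy()
        for j in range(k):
            members = X[labels == j]
            if len(members):  # an empty cluster keeps its old center
                new_centers[j] = members.mean(axis=0)
        if np.allclose(new_centers, centers):
            break  # centers stopped moving
        centers = new_centers
    return centers, nearest_center(X, centers)

# Toy usage: two well-separated blobs in 2-D.
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(10, 0.1, (20, 2))])
centers, labels = kmeans_numpy(X, 2)
print(centers.round(1))
```

On the full MNIST training set this works the same way but converts the images to float64 first, so expect it to be slower and more memory-hungry than the library version.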
```python
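import numpy as np
from sklearn.metrics import accuracy_score

train_pred = kmeans.predict(train_images)
test_pred = kmeans.predict(test_images)
# Cluster indices are arbitrary, so map each cluster to the most common
# training label among its members before computing accuracy.
cluster_to_digit = np.array([
    np.bincount(train_labels[train_pred == k], minlength=10).argmax()
    for k in range(kmeans.n_clusters)
])
train_acc = accuracy_score(train_labels, cluster_to_digit[train_pred])
test_acc = accuracy_score(test_labels, cluster_to_digit[test_pred])
print("Train accuracy: {:.2f}%".format(train_acc * 100))
print("Test accuracy: {:.2f}%".format(test_acc * 100))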
```
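Here Scikit-Learn's `accuracy_score` function measures the accuracy of the clustering result. Because cluster indices carry no inherent digit meaning, the score is only meaningful once each cluster has been mapped to a digit label.
4. Comparison with IsoData
Finally, the K-means result can be compared with IsoData, another classic clustering algorithm. The following code does this: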
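A single accuracy figure hides which digits get confused. One way to analyze the result in more detail is a cluster-by-digit count table together with its purity (the fraction of samples that belong to their cluster's majority digit). The sketch below uses only NumPy; `contingency_table` and `purity` are hypothetical helper names, and the toy labels are made up for illustration.

```python
import numpy as np

def contingency_table(true_labels, cluster_ids, n_classes=10):
    """Rows are clusters, columns are true classes, entries are sample counts."""
    true_labels = np.asarray(true_labels, dtype=np.int64)
    cluster_ids = np.asarray(cluster_ids, dtype=np.int64)
    table = np.zeros((cluster_ids.max() + 1, n_classes), dtype=np.int64)
    # Add 1 at (cluster, class) for every sample; add.at handles repeated indices.
    np.add.at(table, (cluster_ids, true_labels), 1)
    return table

def purity(table):
    """Fraction of samples that share their cluster's majority class."""
    return table.max(axis=1).sum() / table.sum()

# Toy example with three classes and three clusters.
true = [0, 0, 1, 1, 1, 2]
clusters = [1, 1, 0, 0, 2, 2]
table = contingency_table(true, clusters, n_classes=3)
print(table)          # rows: [0 2 0], [2 0 0], [0 1 1]
print(purity(table))  # 5/6, about 0.833
```

For MNIST, calling `contingency_table(train_labels, kmeans.labels_)` would show, row by row, which digits each cluster mixes together.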
```python
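import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score

def isodata_clustering(data, min_samples, max_samples, max_clusters, max_rounds=5):
    """Simplified ISODATA-style clustering (discard and split steps only).

    Starts from K-means with max_clusters clusters, then repeatedly drops
    clusters smaller than min_samples and splits clusters larger than
    max_samples, never exceeding max_clusters. There is no merge step,
    so this is a sketch rather than the full ISODATA algorithm.
    """
    data = np.asarray(data, dtype=np.float32)
    model = KMeans(n_clusters=max_clusters, n_init=10, random_state=42).fit(data)
    for _ in range(max_rounds):
        sizes = np.bincount(model.labels_, minlength=model.n_clusters)
        keep = [k for k in range(model.n_clusters) if sizes[k] >= min_samples]
        changed = len(keep) < model.n_clusters
        budget = max_clusters - len(keep)  # extra clusters that splits may add
        centers = []
        for k in sorted(keep, key=lambda k: -sizes[k]):  # largest clusters first
            members = data[model.labels_ == k]
            center = members.mean(axis=0)
            if sizes[k] > max_samples and budget > 0:
                # Split along the feature with the largest spread.
                j = members.var(axis=0).argmax()
                offset = np.zeros_like(center)
                offset[j] = members[:, j].std()
                centers += [center + offset, center - offset]
                budget -= 1
                changed = True
            else:
                centers.append(center)
        if len(centers) < 2 or not changed:
            break  # nothing left to adjust, or too few clusters remain
        # Refit, starting from the adjusted centers.
        model = KMeans(n_clusters=len(centers), init=np.array(centers),
                       n_init=1, random_state=42).fit(data)
    return model.labels_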
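def evaluate_clustering(clustering, data, labels):
    """Cluster data, map each cluster to its majority digit, return the accuracy."""
    pred = clustering(data)
    # Cluster indices are arbitrary, so relabel each cluster with its most common digit.
    mapping = np.array([np.bincount(labels[pred == k], minlength=10).argmax()
                        for k in range(pred.max() + 1)])
    return accuracy_score(labels, mapping[pred])

def run_isodata(data):
    # The threshold values are illustrative assumptions, not tuned settings.
    return isodata_clustering(data, min_samples=100, max_samples=8000, max_clusters=10)

# Note: the training and test sets are each clustered on their own here.
isodata_train_acc = evaluate_clustering(run_isodata, train_images, train_labels)
isodata_test_acc = evaluate_clustering(run_isodata, test_images, test_labels)
print("IsoData Train accuracy: {:.2f}%".format(isodata_train_acc * 100))
print("IsoData Test accuracy: {:.2f}%".format(isodata_test_acc * 100))
```
Here `isodata_clustering` carries out the IsoData-style clustering, and `evaluate_clustering` maps each cluster to its majority digit and returns the resulting accuracy. The IsoData scores are stored under their own names so that they do not overwrite the K-means scores from section 3.
Finally, print the K-means and IsoData results side by side:
```python
print("Kmeans Train accuracy: {:.2f}%".format(train_acc * 100))
print("Kmeans Test accuracy: {:.2f}%".format(test_acc * 100))
print("IsoData Train accuracy: {:.2f}%".format(isodata_train_acc * 100))
print("IsoData Test accuracy: {:.2f}%".format(isodata_test_acc * 100))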
```
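In this comparison, K-means is reported to cluster MNIST slightly better than IsoData. The size of the gap depends on the IsoData thresholds and the random seed, so check it against the accuracies printed above.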
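ISODATA differs from K-means mainly in that it can change the number of clusters during the run, for example by merging clusters whose centers lie close together. The sketch below illustrates one such merge step, assuming Euclidean distance and a size-weighted average for the merged center; `merge_close_centers` and `min_distance` are hypothetical names, not part of any library.

```python
import numpy as np

def merge_close_centers(centers, sizes, min_distance):
    """Repeatedly merge the closest pair of centers nearer than min_distance."""
    centers = [np.asarray(c, dtype=np.float64) for c in centers]
    sizes = list(sizes)
    while len(centers) > 1:
        best = None
        # Find the closest pair that is below the distance threshold.
        for i in range(len(centers)):
            for j in range(i + 1, len(centers)):
                d = np.linalg.norm(centers[i] - centers[j])
                if d < min_distance and (best is None or d < best[0]):
                    best = (d, i, j)
        if best is None:
            break  # no remaining pair is close enough to merge
        _, i, j = best
        total = sizes[i] + sizes[j]
        # The merged center is the size-weighted mean of the two centers.
        centers[i] = (sizes[i] * centers[i] + sizes[j] * centers[j]) / total
        sizes[i] = total
        del centers[j], sizes[j]
    return np.array(centers), sizes

# Toy usage: the first two centers are 1 unit apart and get merged.
merged, merged_sizes = merge_close_centers([[0, 0], [1, 0], [10, 10]], [1, 3, 5], 2.0)
print(merged)        # rows: [0.75, 0.0] and [10.0, 10.0]
print(merged_sizes)  # [4, 5]
```

How aggressively clusters are merged is controlled by the threshold, which is one reason an ISODATA run can end with a different number of clusters than K-means on the same data.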