Implementing kNN classification on the MNIST handwritten digit dataset and computing the accuracy
Time: 2023-11-15 17:47:50 Views: 116
Sure, I can help you with this task. First, we need to load the MNIST dataset. You can use the following code to download and load it:
```python
import gzip
import urllib.request

import numpy as np


# Download one dataset file into the current directory
def download_data(url):
    filename = url.split("/")[-1]
    urllib.request.urlretrieve(url, filename)
    print(f"{filename} has been downloaded")


# Load an IDX file. Image files have a 16-byte header, label files an 8-byte one.
def load_data(filename, num_images, image_size, header_size=16):
    with gzip.open(filename, "rb") as f:
        f.read(header_size)  # skip the header
        buffer = f.read(image_size * image_size * num_images)
    data = np.frombuffer(buffer, dtype=np.uint8).astype(np.float32)
    return data.reshape(num_images, image_size * image_size)


# Note: this host may refuse automated downloads; if so, fetch the same four files from a mirror
download_data("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")
download_data("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")
download_data("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")
download_data("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")

# Images become (n, 784) float arrays; labels become 1-D integer arrays
X_train = load_data("train-images-idx3-ubyte.gz", 60000, 28)
y_train = load_data("train-labels-idx1-ubyte.gz", 60000, 1, header_size=8).flatten().astype(np.int64)
X_test = load_data("t10k-images-idx3-ubyte.gz", 10000, 28)
y_test = load_data("t10k-labels-idx1-ubyte.gz", 10000, 1, header_size=8).flatten().astype(np.int64)
```
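The header sizes matter because MNIST files use the IDX format: two zero bytes, one data-type byte, one byte giving the number of dimensions, then one 4-byte big-endian integer per dimension. That gives 16 bytes for the image files (3 dimensions) and 8 bytes for the label files (1 dimension). As a minimal sketch, the header can be parsed instead of hard-coded; `parse_idx_header` is a hypothetical helper name, and the headers below are built by hand rather than read from the real files:

```python
import struct


def parse_idx_header(raw: bytes):
    """Return (dtype_code, shape, header_size) for the start of an IDX file."""
    # Bytes 0-1 are zero, byte 2 is the data type, byte 3 is the dimension count
    zeros, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zeros != 0:
        raise ValueError("not an IDX file: the first two bytes must be zero")
    # Each dimension size is a 4-byte big-endian unsigned integer
    shape = struct.unpack(f">{ndim}I", raw[4:4 + 4 * ndim])
    return dtype_code, shape, 4 + 4 * ndim


# Hand-built headers shaped like the MNIST training files (0x08 = unsigned byte)
images_header = struct.pack(">HBBIII", 0, 0x08, 3, 60000, 28, 28)
labels_header = struct.pack(">HBBI", 0, 0x08, 1, 60000)

print(parse_idx_header(images_header))  # (8, (60000, 28, 28), 16)
print(parse_idx_header(labels_header))  # (8, (60000,), 8)
```

With a real file, the same function could be applied to the first 16 bytes read from `gzip.open(filename, "rb")`, which would remove the need to pass the item count and header size by hand.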
Next, we can use the KNeighborsClassifier class from scikit-learn to implement the kNN classifier. Here is the complete code:
```python
from sklearn.neighbors import KNeighborsClassifier
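# Instantiate the kNN classifier
k = 3
knn = KNeighborsClassifier(n_neighbors=k)
# "Train" the classifier (kNN only stores or indexes the training data here)
knn.fit(X_train, y_train)
# Predict on the test set and compute the accuracy
y_pred = knn.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f"kNN classifier with k={k} has an accuracy of {accuracy:.4f}")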
```
The code above uses a kNN classifier with k=3 and computes its accuracy on the test set. You can try different values of k to see how the accuracy changes.
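To make the effect of k concrete, here is a minimal sketch that compares several values of k. It swaps scikit-learn for a small brute-force NumPy implementation so the voting step is visible; `knn_predict` and `accuracy_by_k` are hypothetical helper names, and the data is a synthetic stand-in (two Gaussian blobs), not MNIST. It assumes non-negative integer labels.

```python
import numpy as np


def knn_predict(X_train, y_train, X_test, k):
    """Brute-force kNN: majority vote among the k nearest training points."""
    # Squared Euclidean distances, shape (n_test, n_train)
    d2 = (
        (X_test ** 2).sum(axis=1)[:, None]
        - 2.0 * X_test @ X_train.T
        + (X_train ** 2).sum(axis=1)[None, :]
    )
    # Indices of the k smallest distances in each row (order within the k is arbitrary)
    nearest = np.argpartition(d2, k - 1, axis=1)[:, :k]
    votes = y_train[nearest]  # shape (n_test, k)
    # Most frequent label per row; ties go to the smallest label
    return np.array([np.bincount(row).argmax() for row in votes])


def accuracy_by_k(X_train, y_train, X_test, y_test, k_values):
    """Map each k to the fraction of test points classified correctly."""
    return {
        k: float(np.mean(knn_predict(X_train, y_train, X_test, k) == y_test))
        for k in k_values
    }


# Synthetic stand-in data: two well-separated Gaussian blobs in 5 dimensions
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal(0.0, 1.0, size=(100, 5)),
    rng.normal(6.0, 1.0, size=(100, 5)),
]).astype(np.float32)
y = np.array([0] * 100 + [1] * 100)
idx = rng.permutation(200)
train, test = idx[:150], idx[150:]

results = accuracy_by_k(X[train], y[train], X[test], y[test], [1, 3, 5, 7])
for k, acc in results.items():
    print(f"k={k}: accuracy={acc:.4f}")
```

On the full MNIST arrays the distance matrix would have 10000 × 60000 entries, so this sketch would need to process the test set in batches or use a subsample. Looping over k with `KNeighborsClassifier(n_neighbors=k)` follows the same pattern.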