Downloading and Reading the Fashion-MNIST Dataset
Date: 2024-06-09 14:11:51
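Fashion-MNIST is a dataset for image classification. It contains 70,000 grayscale images of 28x28 pixels across 10 classes, split into 60,000 training images and 10,000 test images. The example code below downloads the dataset files and reads them into memory.
Download the dataset: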
```python
import os
import urllib.request

# URLs of the four gzip-compressed IDX files (training/test images and labels).
url_train = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz'
url_train_label = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz'
url_test = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz'
url_test_label = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz'

# Create the target directory if it does not exist yet.
os.makedirs('./data/fashion_mnist', exist_ok=True)

# Download each file into the target directory, keeping the original file names.
urllib.request.urlretrieve(url_train, './data/fashion_mnist/train-images-idx3-ubyte.gz')
urllib.request.urlretrieve(url_train_label, './data/fashion_mnist/train-labels-idx1-ubyte.gz')
urllib.request.urlretrieve(url_test, './data/fashion_mnist/t10k-images-idx3-ubyte.gz')
urllib.request.urlretrieve(url_test_label, './data/fashion_mnist/t10k-labels-idx1-ubyte.gz')
```
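The download code above fetches all four files again every time it runs. A minimal sketch of a helper that skips files already on disk might look like the following. The name `download_if_missing` and the `.part` temporary suffix are assumptions made for illustration and are not part of the original code.

```python
import os
import urllib.request


def download_if_missing(url, dest_path):
    """Download url to dest_path unless the file already exists.

    Returns True if a download happened, False if it was skipped.
    (Hypothetical helper, not part of the original example.)
    """
    if os.path.exists(dest_path):
        return False
    # Make sure the parent directory exists before writing.
    os.makedirs(os.path.dirname(dest_path) or '.', exist_ok=True)
    # Download under a temporary name first, so an interrupted transfer
    # does not leave a truncated file under the final name.
    tmp_path = dest_path + '.part'
    urllib.request.urlretrieve(url, tmp_path)
    os.replace(tmp_path, dest_path)
    return True
```

It could then be called once per file, for example `download_if_missing(url_train, './data/fashion_mnist/train-images-idx3-ubyte.gz')`.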
Read the dataset:
```python
import gzip
import numpy as np


def load_mnist_images(filename):
    # Skip the 16-byte IDX header; the rest is one uint8 per pixel.
    with gzip.open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    # Reshape the flat pixel buffer into (num_images, 28, 28).
    return data.reshape(-1, 28, 28)


def load_mnist_labels(filename):
    # Skip the 8-byte IDX header; the rest is one uint8 per label.
    with gzip.open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=8)
    return data


train_images = load_mnist_images('./data/fashion_mnist/train-images-idx3-ubyte.gz')
train_labels = load_mnist_labels('./data/fashion_mnist/train-labels-idx1-ubyte.gz')
test_images = load_mnist_images('./data/fashion_mnist/t10k-images-idx3-ubyte.gz')
test_labels = load_mnist_labels('./data/fashion_mnist/t10k-labels-idx1-ubyte.gz')
```
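The `load_mnist_images` and `load_mnist_labels` functions read the dataset files and convert them to NumPy arrays. `train_images` and `test_images` are arrays of shape `(60000, 28, 28)` and `(10000, 28, 28)`, holding the training and test images, each 28x28 pixels. `train_labels` and `test_labels` are arrays of shape `(60000,)` and `(10000,)`, holding the corresponding training and test labels.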
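The fixed offsets of 16 and 8 bytes in the loaders come from the IDX file header: 4 magic bytes (two zero bytes, a data-type code, and the number of dimensions), followed by one 4-byte big-endian size per dimension. Image files have three dimensions (4 + 3 × 4 = 16 bytes) and label files have one (4 + 4 = 8 bytes). As a rough sketch, a single loader could read the shape from the header instead of hard-coding it. The function name `load_idx` is an assumption for illustration, and the sketch only handles the unsigned-byte data type that these files use.

```python
import gzip
import struct

import numpy as np


def load_idx(filename):
    """Read a gzip-compressed IDX file, taking the array shape from its header.

    (Hypothetical helper; handles only unsigned-byte data, type code 0x08.)
    """
    with gzip.open(filename, 'rb') as f:
        raw = f.read()
    # Magic bytes: two zero bytes, a data-type code, and the dimension count.
    zeros, dtype_code, ndim = struct.unpack('>HBB', raw[:4])
    if zeros != 0 or dtype_code != 0x08:
        raise ValueError('not an unsigned-byte IDX file')
    # Each dimension size is a 4-byte big-endian unsigned integer.
    header_size = 4 + 4 * ndim
    shape = struct.unpack('>' + 'I' * ndim, raw[4:header_size])
    # Everything after the header is the raw uint8 payload.
    data = np.frombuffer(raw, dtype=np.uint8, offset=header_size)
    return data.reshape(shape)
```

Under these assumptions, `load_idx('./data/fashion_mnist/train-images-idx3-ubyte.gz')` should give the same array as `load_mnist_images`, and a mismatch between the header and the payload size surfaces as a `ValueError` from `reshape` rather than a silently wrong shape.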