使用torchvision库中的数据模块下载MNIST数据集,并随机展示其中8张图像
时间: 2024-10-22 09:07:19 浏览: 99
在PyTorch中,`torchvision.datasets.MNIST`是一个方便的数据加载工具,用于获取MNIST手写数字数据集。首先,你需要安装`torchvision`,如果还没有安装,可以使用pip安装:
```bash
pip install torchvision
```
接下来,你可以按照以下步骤使用`torchvision`下载并随机显示MNIST数据集中的8张图像:
```python
import torch
from torchvision import datasets, transforms
# 数据预处理,对图像做归一化并将像素值从0-255转为0-1
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 下载和加载MNIST数据集
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# 获取数据集大小
num_images_to_show = min(8, len(mnist_dataset))
# 随机选择8张图像及其对应的标签
selected_indices = torch.randperm(len(mnist_dataset))[:num_images_to_show]
images_and_labels = [(mnist_dataset[i][0], mnist_dataset[i][1]) for i in selected_indices]
# 创建一个新的figure展示这8张图像
import matplotlib.pyplot as plt
fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(10, 5))
for ax, (image, label) in zip(axs.flat, images_and_labels):
ax.imshow(image, cmap='gray')
ax.set_title('Label: {}'.format(label), fontsize=10)
ax.axis('off') # 关闭坐标轴
plt.show()
```
运行这段代码后,你会看到一个包含8张MNIST手写数字的随机图片网格。
阅读全文