使用torchvision库中的数据模块下载MNIST数据集,并随机展示其中8张图像。
时间: 2024-10-22 11:16:44 浏览: 38
Pytorch使用MNIST数据集实现基础GAN和DCGAN详解
TorchVision是一个PyTorch库的一部分,专门用于处理计算机视觉任务的数据。它提供了一个简单的方式来加载各种常用的数据集,包括MNIST手写数字数据集。首先,你需要安装torchvision库,如果你还没有安装,可以使用pip install torchvision命令。
下面是如何使用torchvision的DataLoader和transforms模块下载并随机显示MNIST数据集中8张图像的步骤:
```python
import torch
from torchvision import datasets, transforms
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化像素值到[0,1]
])
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# 创建数据加载器,batch_size设为1以便一次查看一张图片
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
# 遍历数据加载器,取出并显示前8张图片
for i, (images, _) in enumerate(dataloader):
if i == 7: # 只显示8张
break
img = images.squeeze() # 去掉通道维度,变成二维数组
plt.imshow(img.numpy(), cmap='gray') # 显示灰度图
plt.show()
# 或者,你也可以使用imshow函数直接从DataLoader中获取数据
for img, _ in dataloader:
img = img.squeeze()
plt.imshow(img.numpy(), cmap='gray')
plt.show()
if i == 7:
break
```
这段代码会从MNIST训练集中随机选取8张图像进行显示。
阅读全文