2.使用torchvision库中的数据模块下载MNIST数据集,并随机展示其中8张图像。 将2下载的图像数据进行随机翻转,随机裁剪,并且转化为统一尺寸大小120x120,最后转换为tensor数据
时间: 2024-10-22 20:17:29 浏览: 80
Pytorch使用MNIST数据集实现基础GAN和DCGAN详解
首先,你需要安装PyTorch及其 torchvision 库,如果你还没有安装,可以使用以下命令进行安装:
```bash
pip install torch torchvision
```
接下来,我们可以使用以下步骤从MNIST数据集中下载并处理数据:
1. 导入所需的库:
```python
import torch
import torchvision
from torchvision import transforms
```
2. 下载MNIST数据集:
```python
transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转
transforms.RandomCrop(120, padding=10), # 随机裁剪至120x120,填充边界像素
transforms.Resize((120, 120)), # 调整到固定大小
transforms.ToTensor(), # 转换为Tensor
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
```
注意:`root='./data'`需要替换为你希望保存数据的实际路径。
3. 打印和显示随机选取的8张图像:
```python
num_samples_to_show = 8
images_to_display = torch.stack([train_dataset[i][0] for i in torch.randperm(len(train_dataset))[:num_samples_to_show]])
# 将灰度图像转换为RGB以便于可视化
images_to_display = images_to_display.repeat(1, 1, 3) if images_to_display.shape[-1] == 1 else images_to_display
imshow_images = torchvision.utils.make_grid(images_to_display, nrow=4)
plt.imshow(torchvision.transforms.functional.to_pil_image(imshow_images))
plt.show()
```
这里我们使用了`torchvision.utils.make_grid`函数将图片合并成网格并显示出来。
阅读全文