mnist数据集预处理pytorch
时间: 2023-11-09 21:01:55 浏览: 129
在PyTorch中预处理MNIST数据集的步骤如下:
1. 导入必要的库和模块:
```python
import torch
import torchvision
import torchvision.transforms as transforms
```
2. 定义数据转换:
```python
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
```
这里使用了`transforms.Compose`将多个数据转换操作组合在一起,其中`transforms.ToTensor()`将数据转换为张量,`transforms.Normalize()`对数据进行标准化处理。
3. 加载训练集和测试集:
```python
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
shuffle=False, num_workers=2)
```
这里使用了`torchvision.datasets.MNIST`加载MNIST数据集,并将之前定义的数据转换应用到数据集中。然后使用`torch.utils.data.DataLoader`将数据集转换为可迭代的数据加载器。
4. 查看数据:
```python
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % labels[j] for j in range(64)))
```
这里使用了`matplotlib`库来显示图像,`torchvision.utils.make_grid()`将多张图像拼接在一起,`np.transpose()`将通道维度放到最后。
阅读全文