mnist数据集代码
时间: 2023-07-02 14:05:51 浏览: 50
以下是使用PyTorch加载MNIST数据集的示例代码:
```python
import torch
import torchvision
import torchvision.transforms as transforms
# 定义数据预处理
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# 加载训练集
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
shuffle=True)
# 加载测试集
testset = torchvision.datasets.MNIST(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
shuffle=False)
# 展示数据集中的一些图片
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
img = img / 2 + 0.5 # 去归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % labels[j] for j in range(4)))
```
运行以上代码,你将得到MNIST数据集中的一些图片。此外,训练集和测试集也已经被加载到`trainloader`和`testloader`中,你可以使用它们来训练或测试你的模型。