torch.datasets中MNIST数据集如何作为训练集和测试集
时间: 2024-05-09 11:14:59 浏览: 112
在PyTorch中,MNIST数据集可以通过torchvision.datasets模块来获取。可以使用`train=True`来获取训练集,使用`train=False`来获取测试集。以下是一个示例代码,演示如何加载MNIST数据集并将其划分为训练集和测试集:
```python
import torch
import torchvision
import torchvision.transforms as transforms
# 定义数据转换
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# 加载MNIST数据集,划分为训练集和测试集
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
shuffle=False, num_workers=2)
```
在这个例子中,我们首先定义了一个数据转换,将图像转换为张量,并将其标准化为均值为0.5,标准差为0.5的范围内。然后,我们使用`torchvision.datasets.MNIST`函数来加载MNIST数据集,并将其划分为训练集和测试集。最后,我们使用`torch.utils.data.DataLoader`函数定义了训练集和测试集的数据加载器,以便我们可以在模型训练时使用它们。
阅读全文