请用代码实现pytorch获取mnist训练集,并将其分为训练集和验证集
时间: 2024-04-05 17:33:36 浏览: 8
当然可以!下面是使用PyTorch获取MNIST训练集并将其分为训练集和验证集的代码:
```python
import torch
from torchvision import datasets, transforms
from torch.utils.data import random_split
# 定义数据预处理的转换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# 将训练集分割为训练集和验证集
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
# 加载测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
```
上面的代码使用了torchvision库中的`datasets.MNIST`来下载和加载MNIST数据集。然后使用`random_split`函数将训练集分割为训练集和验证集。最后,使用`DataLoader`来创建数据加载器,方便后续的训练和验证过程。