pytorch的数据集预处理代码
时间: 2023-09-14 10:11:26 浏览: 96
以下是一个简单的PyTorch数据集预处理代码示例,其中包括如何读取数据集、对数据进行转换和标准化、创建数据集迭代器等。
```python
import torch
from torchvision import datasets, transforms
# 数据集的预处理
transform = transforms.Compose([
transforms.Resize((28, 28)), # 调整图像大小为28x28
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.5,), (0.5,)) # 标准化图像像素值
])
# 加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建数据集迭代器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
```
在这个例子中,我们首先定义了一个转换器`transform`,用于对数据集中的图像进行预处理。然后,我们使用`datasets`模块中的`MNIST`类加载训练集和测试集。最后,我们使用`DataLoader`类创建数据集迭代器,以便能够在训练和测试过程中有效地读取数据。
阅读全文