pytorch的数据集预处理操作代码
时间: 2023-11-16 10:05:47 浏览: 95
以下是一个示例pytorch数据集预处理操作的代码:
```python
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
# 定义数据集预处理操作
transform = transforms.Compose([
transforms.ToTensor(), # 将数据转换为PyTorch张量
transforms.Normalize((0.1307,), (0.3081,)) # 标准化数据
])
# 加载MNIST数据集
train_data = MNIST('./data', train=True, download=True, transform=transform)
test_data = MNIST('./data', train=False, download=True, transform=transform)
# 创建数据集加载器
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
```
在上面的代码中,我们首先定义了一个`transforms.Compose`对象,该对象包含两个操作:`transforms.ToTensor()`将数据转换为PyTorch张量,`transforms.Normalize()`标准化数据。然后我们加载了MNIST数据集并使用定义的预处理操作进行转换。最后,我们创建了两个数据集加载器,一个用于训练数据,一个用于测试数据。
阅读全文