在torch中导入数据
时间: 2024-10-12 13:07:09 浏览: 26
在PyTorch中,导入数据通常涉及几个关键步骤。首先,你需要安装`torchvision`库,它包含了处理图像、视频等常见数据集的工具。如果你还没有安装,可以使用pip命令:
```bash
pip install torchvision
```
接下来,你可以通过以下方式导入常见的数据集,例如MNIST手写数字数据集:
```python
import torch
from torchvision import datasets, transforms
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化
])
# 加载数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 数据加载器
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_loader, batch_size=batch_size, shuffle=False)
```
这里,我们设置了训练数据和测试数据的加载器,`DataLoader`负责将数据分批读取,便于模型训练。
阅读全文