python使用torch读取mnist数据集
时间: 2024-11-09 11:29:46 浏览: 15
python进阶教程:pytorch 把MNIST数据集转换成图片和txt的方法
在PyTorch中,你可以使用`torchvision`库来方便地加载MNIST数据集,这是一个经典的手写数字识别数据集。以下是使用`torchvision.datasets.MNIST`加载和预处理MNIST数据的基本步骤:
```python
import torch
from torchvision import datasets, transforms
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化到[0,1]范围
])
# 加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 获取训练集和测试集的数据和标签
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) # 创建数据加载器,分批读取数据
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 对于每个批次的样本,它们都是一个包含两个元素的元组 (images, labels),images是张量,labels是整数
for images, labels in train_loader:
print(images.shape, labels.shape)
阅读全文