用pytorch怎么获取
时间: 2023-08-04 12:16:18 浏览: 35
使用PyTorch加载MNIST数据集的代码如下:
```python
import torch
import torchvision
from torchvision import transforms
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 将图片转换为PyTorch张量
transforms.Normalize((0.1307,), (0.3081,)) # 归一化像素值到-1到1之间
])
# 加载MNIST数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 获取前5张训练图片
few_train_images = [trainset[i][0] for i in range(5)]
```
在上面的代码中,我们使用了PyTorch内置的`torchvision.datasets.MNIST`模块来加载MNIST数据集。我们使用了`transforms.Compose`方法来定义数据预处理,将图片转换为PyTorch张量并归一化像素值到-1到1之间。然后,我们可以通过指定`train=True`来加载训练集,`train=False`来加载测试集。最后,我们可以使用Python的列表推导式来获取前5张训练图片。