torch怎么读取图片
时间: 2023-12-02 09:27:11 浏览: 67
在 PyTorch 中,可以使用 `torchvision` 库来读取图片。以下是一个简单的例子:
```python
import torch
from torchvision import datasets, transforms
# 定义数据转换,例如将图片转换成 PyTorch Tensor 格式,并进行归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载数据集
dataset = datasets.ImageFolder(root='path/to/dataset', transform=transform)
# 创建数据加载器
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
# 读取数据
for images, labels in loader:
# images 是一个大小为 [32, 3, 224, 224] 的 Tensor,表示 32 张大小为 224x224、通道数为 3 的图片
# labels 是一个大小为 [32] 的 Tensor,表示 32 张图片对应的标签
pass
```
需要注意的是,`datasets.ImageFolder` 函数会自动遍历指定目录下的所有子目录,并将每个子目录中的图片视为一个类别。因此,如果你的数据集中有多个类别,需要将它们分别放在不同的子目录中。例如:
```
path/to/dataset/
├── cat/
│ ├── cat1.jpg
│ ├── cat2.jpg
│ └── ...
└── dog/
├── dog1.jpg
├── dog2.jpg
└── ...
```
阅读全文