pytorch 读取文件夹中的图片为 [batch_size, num_channels, height, width]
时间: 2023-06-01 07:04:04 浏览: 105
下面是一个示例代码,使用PyTorch中的`torchvision`库来读取文件夹中的图片并转换为指定的张量形状:
```python
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
# 定义数据预处理操作
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
])
# 创建数据集对象
dataset = datasets.ImageFolder(root='/path/to/folder', transform=transform)
# 创建数据加载器对象
batch_size = 32
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 读取一个批次的图片数据
images, labels = next(iter(loader))
# 打印张量形状
print(images.shape) # 输出:[batch_size, num_channels, height, width]
```
在上面的代码中,`transform`参数定义了一系列数据预处理操作,包括将图片缩放为256x256大小,并将其转换为张量形式。然后,使用`ImageFolder`类创建了一个数据集对象,它会自动从指定的文件夹中读取图片,并将其应用到定义好的预处理操作。最后,使用`DataLoader`类创建了一个数据加载器对象,它可以对数据集进行批量读取和随机打乱等操作。通过调用`next(iter(loader))`方法,可以读取一个批次的图片数据,并将其转换为指定的张量形状。
阅读全文