pytorch stl10数据预处理
时间: 2023-11-14 12:24:56 浏览: 491
STL-10是一个常用的图像数据集,它包含10个类别的图像,每个类别有500张训练图像和800张测试图像。下面是使用PyTorch进行STL-10数据预处理的示例代码:
```python
import torch
import torchvision.transforms as transforms
from torchvision.datasets import STL10
# 定义预处理函数
transform = transforms.Compose([
transforms.Resize((96, 96)), # 调整图像大小为96x96
transforms.ToTensor(), # 将图像转换为Tensor格式
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 对图像进行归一化
])
# 加载数据集
train_dataset = STL10(root='./data', split='train', transform=transform, download=True)
test_dataset = STL10(root='./data', split='test', transform=transform, download=True)
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
```
在上述代码中,我们使用了PyTorch提供的transforms模块对图像进行预处理,包括调整图像大小、转换为Tensor格式以及归一化。然后使用STL10数据集的API加载数据集,并创建数据加载器,以便于我们在训练和测试时使用。
阅读全文