保存CIFAR数据集500张32*32像素和对应的标签的代码
时间: 2023-11-25 18:04:16 浏览: 71
好的,以下是保存CIFAR数据集500张32*32像素和对应的标签的代码:
```python
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# 定义数据转换
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载CIFAR-10数据集
cifar_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# 取出前500张图片和标签
data = []
labels = []
for i in range(500):
data.append(cifar_dataset[i][0])
labels.append(cifar_dataset[i][1])
# 将数据和标签转换为torch.tensor格式
data_tensor = torch.stack(data, dim=0)
labels_tensor = torch.tensor(labels)
# 保存数据和标签
torch.save(data_tensor, 'cifar_data_500.pt')
torch.save(labels_tensor, 'cifar_labels_500.pt')
```
这段代码中,我们首先定义了数据转换,然后加载CIFAR-10数据集,并取出前500张图片和标签。接着,我们将数据和标签转换为torch.tensor格式,并使用torch.save()函数将它们分别保存到`cifar_data_500.pt`和`cifar_labels_500.pt`文件中。
阅读全文