CIFAR 数据集格式代码
时间: 2023-12-25 17:04:47 浏览: 20
以下是一段示例代码,用于读取 Python 版本的 CIFAR-10 数据集(`data_batch_1` ~ `data_batch_5` 和 `test_batch`),并将训练集与测试集以 NumPy 数组的形式返回:
```python
import os
import pickle

import numpy as np
def _load_batch(batch_path):
    """Read one pickled CIFAR-10 batch file and return ``(data, labels)``."""
    # NOTE: pickle.load can execute arbitrary code embedded in the file.
    # Only call this on batch files obtained from a trusted source.
    with open(batch_path, 'rb') as f:
        # encoding='bytes' makes the dictionary keys load as bytes objects,
        # which is why the lookups below use b'data' / b'labels'.
        batch = pickle.load(f, encoding='bytes')
    return batch[b'data'], batch[b'labels']


def _to_images(flat_pixels):
    """Scale flat pixel rows to [0, 1] and lay them out as (N, 32, 32, 3)."""
    scaled = np.asarray(flat_pixels).astype('float32') / 255.0
    # Each row holds 3 * 32 * 32 values stored channel-first; reshape to
    # (N, channels, height, width), then move channels to the last axis.
    return scaled.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)


def load_cifar10_data(data_dir):
    """Load the CIFAR-10 training and test sets as NumPy arrays.

    Args:
        data_dir: Directory containing the files ``data_batch_1`` ...
            ``data_batch_5`` and ``test_batch``.

    Returns:
        A tuple ``(train_X, train_Y, test_X, test_Y)``. The image arrays are
        float32 with values in [0, 1] and shape (N, 32, 32, 3); the label
        arrays are 1-D arrays with one entry per image.

    Raises:
        OSError: If one of the batch files cannot be opened.
        KeyError: If a batch file lacks the ``b'data'`` or ``b'labels'`` key.
    """
    # Training set: gather all five batches first and concatenate once,
    # instead of re-copying the growing array after every batch.
    train_pixels = []
    train_labels = []
    for i in range(1, 6):
        batch_path = os.path.join(data_dir, 'data_batch_{}'.format(i))
        pixels, labels = _load_batch(batch_path)
        train_pixels.append(pixels)
        train_labels.extend(labels)

    # Test set: a single batch file.
    test_pixels, test_labels = _load_batch(os.path.join(data_dir, 'test_batch'))

    train_X = _to_images(np.concatenate(train_pixels, axis=0))
    test_X = _to_images(test_pixels)

    # Convert the label lists to NumPy arrays.
    train_Y = np.array(train_labels)
    test_Y = np.array(test_labels)
    return train_X, train_Y, test_X, test_Y
```
要使用此代码,只需将 `data_dir` 指向存放 CIFAR-10 批次文件(`data_batch_1` ~ `data_batch_5` 以及 `test_batch`)的目录,然后调用 `load_cifar10_data` 函数即可:
```python
# Directory that holds the CIFAR-10 batch files read by load_cifar10_data.
data_dir = '/path/to/cifar10/data'

# The loader returns four arrays: training images, training labels,
# test images and test labels, in that order.
train_X, train_Y, test_X, test_Y = load_cifar10_data(data_dir=data_dir)
```