读取CIFAR-10数据集
时间: 2023-12-21 07:05:46 浏览: 43
CIFAR-10数据集是一个用于图像识别的经典数据集,包含10个类别的60000张32x32彩色图像,每个类别有6000张图像。
下面是用Python读取CIFAR-10数据集的示例代码:
```python
import pickle
import numpy as np
def unpickle(file):
with open(file, 'rb') as f:
dict = pickle.load(f, encoding='bytes')
return dict
# 读取训练集数据
train_data = None
train_labels = []
for i in range(1, 6):
data_dict = unpickle('cifar-10-batches-py/data_batch_{}'.format(i))
if i == 1:
train_data = data_dict[b'data']
else:
train_data = np.vstack((train_data, data_dict[b'data']))
train_labels += data_dict[b'labels']
# 读取测试集数据
test_data_dict = unpickle('cifar-10-batches-py/test_batch')
test_data = test_data_dict[b'data']
test_labels = test_data_dict[b'labels']
# 将数据转换为图像格式
train_data = train_data.reshape((len(train_data), 3, 32, 32)).transpose(0, 2, 3, 1)
test_data = test_data.reshape((len(test_data), 3, 32, 32)).transpose(0, 2, 3, 1)
# 显示数据信息
print('训练集数据:', train_data.shape)
print('训练集标签:', len(train_labels))
print('测试集数据:', test_data.shape)
print('测试集标签:', len(test_labels))
```
运行以上代码,即可读取CIFAR-10数据集。需要注意的是,需要将数据转换为图像格式,并且标签是从0开始的数字。你可以根据需要对数据进行预处理。