Code for the CIFAR dataset
Time: 2023-07-09 14:32:44
The following Python example reads the CIFAR-10 dataset and applies some simple preprocessing:
```python
import os
import pickle

import numpy as np


# Read the dataset
def load_cifar10_data(data_dir):
    train_batches = []
    train_labels = []
    # The training set is split across data_batch_1 ... data_batch_5
    for batch_id in range(1, 6):
        path = os.path.join(data_dir, 'data_batch_' + str(batch_id))
        with open(path, mode='rb') as file:
            batch = pickle.load(file, encoding='latin1')
        train_batches.append(batch['data'])
        train_labels += batch['labels']
    # Stack the per-batch arrays once instead of growing an array in the loop
    train_data = np.vstack(train_batches)

    with open(os.path.join(data_dir, 'test_batch'), mode='rb') as file:
        test_batch = pickle.load(file, encoding='latin1')
    test_data = test_batch['data']
    test_labels = test_batch['labels']
    return train_data, np.array(train_labels), test_data, np.array(test_labels)


# Preprocess the data
def preprocess_data(train_data, test_data):
    # Convert to float32 first (avoids a float64 temporary),
    # then scale the pixel values from 0-255 to the range 0-1
    train_data = train_data.astype(np.float32) / 255.0
    test_data = test_data.astype(np.float32) / 255.0
    return train_data, test_data


# Read the dataset
data_dir = 'cifar-10-batches-py'
train_data, train_labels, test_data, test_labels = load_cifar10_data(data_dir)

# Preprocess the data
train_data, test_data = preprocess_data(train_data, test_data)

# Print dataset information
print('Training data shape: ', train_data.shape)
print('Training labels shape: ', train_labels.shape)
print('Test data shape: ', test_data.shape)
print('Test labels shape: ', test_labels.shape)
```
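In the code above, the load_cifar10_data function reads the dataset and returns the training data, training labels, test data, and test labels. The preprocess_data function preprocesses the data by scaling the pixel values to the range 0-1 and converting them to floating point. Finally, the script prints the shapes of the training data, training labels, test data, and test labels. For the standard Python version of CIFAR-10, these are (50000, 3072) and (50000,) for the training set and (10000, 3072) and (10000,) for the test set.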
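The arrays returned above are flat: each row holds one image as 3072 values. As a minimal sketch of a common next step, the snippet below reshapes such rows into 32x32 RGB images. It relies on the layout of the Python version of CIFAR-10, where each row stores 1024 red values, then 1024 green, then 1024 blue, each in row-major order. The helper name rows_to_images and the random stand-in batch are assumptions made for illustration; they are not part of the original code or of the dataset.

```python
import numpy as np


def rows_to_images(flat_data):
    """Convert flat CIFAR-10 rows of shape (N, 3072) to images of shape (N, 32, 32, 3).

    Hypothetical helper for illustration only.
    """
    flat_data = np.asarray(flat_data)
    if flat_data.ndim != 2 or flat_data.shape[1] != 3072:
        raise ValueError('expected an array of shape (N, 3072)')
    # Split each row into 3 colour planes of 32 x 32 pixels: (N, 3, 32, 32)
    planes = flat_data.reshape(-1, 3, 32, 32)
    # Move the channel axis last: (N, 32, 32, 3), the layout most plotting tools expect
    return planes.transpose(0, 2, 3, 1)


# Random stand-in for a real batch, so the sketch runs without the dataset files
rng = np.random.default_rng(0)
fake_batch = rng.integers(0, 256, size=(4, 3072), dtype=np.uint8)

# Same preprocessing as preprocess_data: float32 values scaled to 0-1
scaled = fake_batch.astype(np.float32) / 255.0
images = rows_to_images(scaled)
print(images.shape)  # (4, 32, 32, 3)
```

With the real data, the same call would be rows_to_images(train_data) after preprocessing. The transpose returns a view rather than a copy, so call np.ascontiguousarray on the result if a contiguous array is needed.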