Give Python code to read cifar-10-python.tar.gz
Posted: 2023-07-30 18:06:45
Here is example Python code for loading the CIFAR-10 dataset:
```python
import os
import pickle

import numpy as np


def unpickle(file):
    """Load one CIFAR-10 batch file (a Python pickle) and return its dict."""
    with open(file, 'rb') as fo:
        # The batches were pickled under Python 2, so the keys come back as bytes
        batch = pickle.load(fo, encoding='bytes')
    return batch


def load_cifar10_data(data_dir):
    """Load the five training batches and the test batch from the extracted directory."""
    train_data = []
    train_labels = []
    for i in range(1, 6):
        filename = os.path.join(data_dir, 'data_batch_{}'.format(i))
        data_dict = unpickle(filename)
        train_data.append(data_dict[b'data'])      # uint8 array, shape (10000, 3072)
        train_labels.append(data_dict[b'labels'])  # list of 10000 ints in 0..9
    train_data = np.concatenate(train_data, axis=0)
    train_labels = np.concatenate(train_labels, axis=0)

    test_data_dict = unpickle(os.path.join(data_dir, 'test_batch'))
    test_data = test_data_dict[b'data']
    test_labels = np.array(test_data_dict[b'labels'])
    return train_data, train_labels, test_data, test_labels


# Directory produced by extracting cifar-10-python.tar.gz
data_dir = 'path/to/cifar-10-batches-py'
train_data, train_labels, test_data, test_labels = load_cifar10_data(data_dir)
```
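This code loads the CIFAR-10 dataset from the `cifar-10-batches-py` directory, so `cifar-10-python.tar.gz` must be extracted first (for example with `tar -xzf cifar-10-python.tar.gz`). `train_data` holds the training images as a uint8 array of shape (50000, 3072), `train_labels` holds the matching 50000 labels (integers 0 to 9), `test_data` holds the test images with shape (10000, 3072), and `test_labels` holds the 10000 test labels. Each row stores one 32×32 image as 1024 red values, then 1024 green values, then 1024 blue values.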
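The question asks about the `.tar.gz` archive itself, while the code above expects it to be extracted already. As a minimal sketch, assuming the archive keeps its standard layout (all batch files under a `cifar-10-batches-py/` prefix), the standard-library `tarfile` module can read the batches straight from the archive without writing anything to disk. The function name `load_cifar10_from_tar` is hypothetical, and this version also reshapes each 3072-byte row into a 32×32×3 image:

```python
import pickle
import tarfile

import numpy as np


def load_cifar10_from_tar(tar_path):
    """Hypothetical helper: read CIFAR-10 directly from cifar-10-python.tar.gz.

    Returns images as uint8 arrays of shape (N, 32, 32, 3) and labels as int64 arrays.
    """
    def read_batch(tar, name):
        # Assumes members are stored under the 'cifar-10-batches-py/' prefix
        member = tar.extractfile('cifar-10-batches-py/' + name)
        batch = pickle.load(member, encoding='bytes')
        # Each row: 1024 R, 1024 G, 1024 B values, each plane a row-major 32x32 grid.
        # Reshape to (N, channel, height, width), then move channels last.
        images = batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        labels = np.array(batch[b'labels'], dtype=np.int64)
        return images, labels

    with tarfile.open(tar_path, 'r:gz') as tar:
        train = [read_batch(tar, 'data_batch_{}'.format(i)) for i in range(1, 6)]
        x_test, y_test = read_batch(tar, 'test_batch')

    # Stack the five training batches into single arrays
    x_train = np.concatenate([x for x, _ in train], axis=0)
    y_train = np.concatenate([y for _, y in train], axis=0)
    return x_train, y_train, x_test, y_test
```

Usage would look like `x_train, y_train, x_test, y_test = load_cifar10_from_tar('path/to/cifar-10-python.tar.gz')`, giving `x_train` with shape (50000, 32, 32, 3). The channels-last layout suits `matplotlib.pyplot.imshow`; if a framework expects channels first (N, 3, 32, 32), drop the `.transpose(0, 2, 3, 1)` call.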