从本地文件加载 CIFAR-10 数据集的代码
时间: 2023-05-21 07:00:56 浏览: 59
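以下是从本地文件加载 CIFAR-10 数据集（Python pickle 版本）的代码。`data_dir` 应指向解压后的目录（通常为 `cifar-10-batches-py`），其中需包含 `data_batch_1` ~ `data_batch_5` 和 `test_batch` 文件：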
```python
import os
import pickle

import numpy as np
def unpickle(file):
    """Load and return the object stored in the pickle file *file*.

    The file is read with ``encoding='bytes'``, so 8-bit strings pickled
    under Python 2 come back as ``bytes``. This is why the CIFAR-10 loader
    indexes the result with ``b'data'`` / ``b'labels'``.

    Args:
        file: Path to the pickle file.

    Returns:
        The unpickled object (for CIFAR-10 batch files, a dict).

    Security:
        Unpickling can execute arbitrary code. Only call this on files from
        a trusted source, such as the official CIFAR-10 download.
    """
    with open(file, 'rb') as fo:
        # Named ``obj`` rather than ``dict`` to avoid shadowing the builtin.
        obj = pickle.load(fo, encoding='bytes')
    return obj
def load_cifar10_data(data_dir):
    """Load the CIFAR-10 training and test splits from pickled batch files.

    Reads ``data_batch_1`` .. ``data_batch_5`` and ``test_batch`` from
    *data_dir* with :func:`unpickle`. Each batch is expected to be a dict
    with ``b'data'`` and ``b'labels'`` keys.

    Args:
        data_dir: Directory containing the batch files, as a ``str`` or a
            path-like object (presumably the extracted ``cifar-10-batches-py``
            directory).

    Returns:
        A tuple ``(train_data, train_labels, test_data, test_labels)``.
        Each image array is reshaped from flat rows to ``(N, 3, 32, 32)``,
        then transposed to channel-last ``(N, 32, 32, 3)``. The label lists
        are converted to NumPy arrays.
    """

    def _to_channel_last(rows):
        # Interpret each flat row as (3, 32, 32), i.e. channel-first, then
        # move the channel axis to the end.
        return rows.reshape((len(rows), 3, 32, 32)).transpose(0, 2, 3, 1)

    train_chunks = []
    train_labels = []
    for i in range(1, 6):
        # os.path.join handles trailing separators and path-like objects,
        # unlike string concatenation with "/".
        batch = unpickle(os.path.join(data_dir, f"data_batch_{i}"))
        train_chunks.append(batch[b'data'])
        train_labels += batch[b'labels']
    # Stack once at the end. Calling vstack inside the loop re-copied the
    # whole accumulated array on every iteration.
    train_data = np.vstack(train_chunks)

    test_batch = unpickle(os.path.join(data_dir, "test_batch"))
    test_data = test_batch[b'data']
    test_labels = test_batch[b'labels']

    return (
        _to_channel_last(train_data),
        np.array(train_labels),
        _to_channel_last(test_data),
        np.array(test_labels),
    )
```
希望这个代码能够帮到你!