解析cifar10数据的python代码
时间: 2023-10-16 08:24:06 浏览: 109
以下是解析CIFAR-10数据集的Python代码示例:
```python
import pickle
import numpy as np
def unpickle(file):
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
def load_data():
train_data = None
train_labels = []
for i in range(1, 6):
data_dict = unpickle("data_batch_{}".format(i))
if i == 1:
train_data = data_dict[b"data"]
else:
train_data = np.vstack((train_data, data_dict[b"data"]))
train_labels += data_dict[b"labels"]
test_data_dict = unpickle("test_batch")
test_data = test_data_dict[b"data"]
test_labels = test_data_dict[b"labels"]
return train_data, np.array(train_labels), test_data, np.array(test_labels)
train_data, train_labels, test_data, test_labels = load_data()
```
运行以上代码后,将会得到以下四个变量:
- `train_data`: 形状为 `(50000, 3072)` 的numpy数组,每一行表示一张图片,总共有50000张训练集图片。
- `train_labels`: 形状为 `(50000,)` 的numpy数组,每个元素是一个整数,表示对应图片的标签。
- `test_data`: 形状为 `(10000, 3072)` 的numpy数组,每一行表示一张图片,总共有10000张测试集图片。
- `test_labels`: 形状为 `(10000,)` 的numpy数组,每个元素是一个整数,表示对应图片的标签。
注意:以上代码需要将CIFAR-10数据集中的数据文件和代码放在同一目录下。
阅读全文