将stl10数据集保存为npy文件
时间: 2024-05-02 15:19:22 浏览: 198
可以使用Python中的NumPy库直接将STL10数据集保存为npy文件。可以按照以下步骤操作:
1. 下载STL10数据集,并导入需要的模块和库:
```python
import numpy as np
import os
from skimage import io
```
2. 通过以下代码读取数据:
```python
def read_stl10_data(path):
data = {}
data['X_train'] = read_single_file(os.path.join(path, 'train_X.bin'), np.uint8, (5000, 3, 96, 96))
data['X_test'] = read_single_file(os.path.join(path, 'test_X.bin'), np.uint8, (8000, 3, 96, 96))
data['y_train'] = read_labels(os.path.join(path, 'train_y.bin')) - 1
data['y_test'] = read_labels(os.path.join(path, 'test_y.bin')) - 1
return data
def read_single_file(filename, dtype, shape):
with open(filename, 'rb') as f:
return np.fromfile(f, dtype=dtype).reshape(shape)
def read_labels(filename):
with open(filename, 'rb') as f:
return np.fromfile(f, dtype=np.uint8)
```
3. 将数据集保存为npy文件:
```python
data_path = '/path/to/stl10_data/'
data = read_stl10_data(data_path)
np.save('stl10_data.npy', data)
```
这样,STL10数据集就会被保存为名为“stl10_data.npy”的文件。
阅读全文