用Python载入STL-10数据集的训练集和测试集，并保存为npy文件
时间: 2024-04-29 16:25:19 浏览: 301
可以的，以下是用Python载入STL-10数据集的训练集和测试集、并保存为npy文件的示例：
```
import numpy as np
import os
import urllib.request
import tarfile
from PIL import Image
def download_data(url='http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz',
                  root='.'):
    """Download and extract the STL-10 binary archive unless already present.

    The archive is saved as ``stl10_binary.tar.gz`` inside *root* and
    extracted into *root*. The download is skipped when the archive file
    exists, and the extraction is skipped when ``<root>/stl10_binary``
    already exists as a directory.

    Args:
        url: location of the ``.tar.gz`` archive.
        root: directory that receives both the archive and the extracted
            ``stl10_binary`` folder. Defaults to the current directory.

    Raises:
        ValueError: if an archive member's path would resolve outside *root*.
        Errors from ``urllib.request.urlretrieve`` / ``tarfile`` (network or
        filesystem failures) are propagated unchanged.
    """
    archive = os.path.join(root, 'stl10_binary.tar.gz')
    if not os.path.isfile(archive):
        # Download under a temporary name and rename only after the transfer
        # completes. Otherwise an interrupted download would leave a truncated
        # archive that the isfile() check above accepts on every later run.
        partial = archive + '.part'
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, archive)

    if not os.path.isdir(os.path.join(root, 'stl10_binary')):
        dest = os.path.realpath(root)
        with tarfile.open(archive) as tar:
            # The archive comes from the network. Refuse members whose path
            # (e.g. '../x' or an absolute path) would land outside *root*.
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(dest, member.name))
                if os.path.commonpath([dest, target]) != dest:
                    raise ValueError(
                        'unsafe path in archive: %r' % (member.name,))
            if hasattr(tarfile, 'data_filter'):
                # Extraction filters (Python 3.12+ and security backports)
                # additionally reject unsafe links and special files.
                tar.extractall(dest, filter='data')
            else:
                tar.extractall(dest)
def read_data(split, path='stl10_binary'):
    """Load one split of the extracted STL-10 binary dataset.

    Args:
        split: ``'train'`` or ``'test'``. Selects ``<split>_X.bin`` and
            ``<split>_y.bin`` inside *path*.
        path: directory holding the extracted ``.bin`` files. Defaults to
            ``'stl10_binary'`` relative to the current working directory.

    Returns:
        ``(X, y)``:
        - ``X`` is a uint8 array of shape ``(N, 96, 96, 3)``.
        - ``y`` is a uint8 array of the ``N`` labels exactly as stored in
          the file. NOTE(review): the STL-10 documentation numbers classes
          1-10; confirm before using them as 0-based class indices.

    Raises:
        ValueError: if *split* is not ``'train'`` or ``'test'``, or if the
            label file and image file disagree on the number of samples.
    """
    if split not in ('train', 'test'):
        # Previously an unknown split silently returned None, which only
        # failed later when the caller tried to unpack the result.
        raise ValueError("split must be 'train' or 'test', got %r" % (split,))

    with open(os.path.join(path, split + '_X.bin'), 'rb') as f:
        images = np.fromfile(f, dtype=np.uint8)
    # Per the STL-10 format notes, each image is stored one channel at a
    # time, and each channel column-major. After reshaping, the axes are
    # therefore (N, channel, column, row). Transposing with (0, 3, 2, 1)
    # gives the conventional (N, row, column, channel) layout.
    images = np.reshape(images, (-1, 3, 96, 96))
    images = np.transpose(images, (0, 3, 2, 1))

    with open(os.path.join(path, split + '_y.bin'), 'rb') as f:
        labels = np.fromfile(f, dtype=np.uint8)

    if labels.shape[0] != images.shape[0]:
        raise ValueError(
            '%s split has %d images but %d labels'
            % (split, images.shape[0], labels.shape[0]))
    return images, labels
def save_data(X, y, filename):
    """Bundle images and labels into one ``.npy`` file at *filename*.

    The arrays are stored together in a dict under the keys ``'X'`` and
    ``'y'``. A dict is not a plain numeric array, so ``np.save`` writes it as
    a pickled object. Read it back with
    ``np.load(filename, allow_pickle=True).item()``.
    """
    bundle = dict(X=X, y=y)
    np.save(filename, bundle)
if __name__ == '__main__':
    # Make sure the raw data is on disk, load both splits first, and only
    # then write one .npy file per split.
    download_data()
    jobs = (('train', 'train.npy'), ('test', 'test.npy'))
    loaded = [(read_data(split), out_path) for split, out_path in jobs]
    for (X, y), out_path in loaded:
        save_data(X, y, out_path)
```
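这段代码首先下载STL-10数据集的压缩文件（若本地已存在则跳过下载和解压），然后读取训练集和测试集的图像和标签，并转换为numpy数组（图像形状为 (N, 96, 96, 3)，类型为uint8），最后分别保存为 train.npy 和 test.npy。注意：每个npy文件中保存的是一个包含 'X' 和 'y' 的字典，因此读取时需要写成 `np.load('train.npy', allow_pickle=True).item()`。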
阅读全文