def _load_img(file_name): file_path=dataset_dir+"/"+file_name print("Converting"+file_name+"to Numpy Array ...") with gzip.open(file_path,'rb') as f: data = np.frombuffer(f.read(),np.uint8,offset=16) data = data.reshape(-1,img_size) print("Done") return data
时间: 2024-04-28 19:20:24 浏览: 94
这段代码是一个函数,函数名为 `_load_img`,接受一个参数 `file_name`。该函数主要完成读取经过 gzip 压缩的图像文件,将其转换为 Numpy 数组并返回。
具体实现流程如下:
1. 定义变量 `file_path`,用于存储数据集文件夹路径和文件名的组合。
2. 打印提示信息,表示正在将文件转换为 Numpy 数组。
3. 使用 gzip 库的 `open` 函数读取并打开指定路径下的文件,并将文件内容转换为 Numpy 数组,数据类型为 `np.uint8`,偏移量为 16。
4. 使用 Numpy 的 `reshape` 函数将数组形状重塑为 (-1,img_size),其中 -1 表示自动推断维度,img_size 表示图像大小。
5. 打印提示信息,表示转换完成。
6. 返回转换后的 Numpy 数组。
相关问题
import os.path import gzip import pickle import os import numpy as np import urllib url_base = 'http://yann.lecun.com/exdb/mnist/' key_file = { 'train_img':'train-images-idx3-ubyte.gz', 'train_label':'train-labels-idx1-ubyte.gz', 'test_img':'t10k-images-idx3-ubyte.gz', 'test_label':'t10k-labels-idx1-ubyte.gz' } dataset_dir = os.path.dirname(os.path.abspath("_file_")) save_file = dataset_dir + "/mnist.pkl" train_num=60000 test_num=10000 img_dim=(1,28,28) img_size=784 def _download(file_name): file_path = dataset_dir+"/"+file_name if os.path.exists(file_path): return print("Downloading"+file_name+" ... ") urllib.request.urlretrieve(url_base + file_name,file_path) print("Done") def download_mnist(): for v in key_file.values(): _download(v) def _load_label(file_name): file_path = dataset_dir+ "/" +file_name print("Converting" + file_name +"to Numpy Array ...") with gzip.open(file_path,'rb') as f: labels = np.frombuffer(f.read(),np.uint8,offset=8) print("Done") return labels def _load_img(file_name): file_path=dataset_dir+"/"+file_name print("Converting"+file_name+"to Numpy Array ...") with gzip.open(file_path,'rb') as f: data = np.frombuffer(f.read(),np.uint8,offset=16) data = data.reshape(-1,img_size) print("Done") return data def _convert_numpy(): dataset = {} dataset['train_img'] = _load_img(key_file['train_img']) dataset['train_label'] = _load_label(key_file['train_label']) dataset['test_img'] = _load_img(key_file['test_img']) dataset['test_label'] = _load_label(key_file['test_label']) return dataset def init_mnist(): download_mnist() dataset = _convert_numpy() print("Creating pickle file ...") with open(save_file,'wb') as f: pickle.dump(dataset,f,-1) print("Done") if __name__ =='__main__': init_mnist()
这段代码是用于下载MNIST数据集并将数据集转换成Numpy数组格式的函数。MNIST数据集是一个手写数字识别数据集,包含了60000张训练图片和10000张测试图片。在函数中,首先定义了数据集的下载地址和四个文件的名称,然后定义了四个函数用于下载和转换数据集。其中,_load_label和_load_img函数用于将标签和图片数据转换成Numpy数组格式,并且在转换过程中使用了gzip库解压缩数据。_convert_numpy函数用于将四个Numpy数组合成一个字典类型的数据集。最后,init_mnist函数用于下载数据集并将数据集转换成Numpy数组格式,并使用pickle库将数据集保存到本地文件中。如果MNIST数据集已经下载并保存到本地文件中,则直接加载本地文件中的数据集。
X_train,T_train=idx2numpy.convert_from_file('emnist/emnist-letters-train-images-idx3-ubyte'),idx2numpy.convert_from_file('emnist/emnist-letters-train-labels-idx1-ubyte')转化为相同形式train_num = 60000 test_num = 10000 img_dim = (1, 28, 28) img_size = 784 def _download(file_name): file_path = dataset_dir + "/" + file_name if os.path.exists(file_path): return print("Downloading " + file_name + " ... ") urllib.request.urlretrieve(url_base + file_name, file_path) print("Done") def download_mnist(): for v in key_file.values(): _download(v) def _load_label(file_name): file_path = dataset_dir + "/" + file_name print("Converting " + file_name + " to NumPy Array ...") with gzip.open(file_path, 'rb') as f: labels = np.frombuffer(f.read(), np.uint8, offset=8) print("Done") return labels def _load_img(file_name): file_path = dataset_dir + "/" + file_name print("Converting " + file_name + " to NumPy Array ...") with gzip.open(file_path, 'rb') as f: data = np.frombuffer(f.read(), np.uint8, offset=16) data = data.reshape(-1, img_size) print("Done") return data def _convert_numpy(): dataset = {} dataset['train_img'] = _load_img(key_file['train_img']) dataset['train_label'] = _load_label(key_file['train_label']) dataset['test_img'] = _load_img(key_file['test_img']) dataset['test_label'] = _load_label(key_file['test_label']) return dataset def init_mnist(): download_mnist() dataset = _convert_numpy() print("Creating pickle file ...") with open(save_file, 'wb') as f: pickle.dump(dataset, f, -1) print("Done!") def _change_one_hot_label(X): T = np.zeros((X.size, 10)) for idx, row in enumerate(T): row[X[idx]] = 1 return T def load_mnist(normalize=True, flatten=True, one_hot_label=False): """读入MNIST数据集 Parameters ---------- normalize : 将图像的像素值正规化为0.0~1.0 one_hot_label : one_hot_label为True的情况下,标签作为one-hot数组返回 one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组 flatten : 是否将图像展开为一维数组 Returns ------- (训练图像, 训练标签), (测试图像, 测试标签) """ if not os.path.exists(save_file): init_mnist() with open(save_file, 'rb') as f: dataset = pickle.load(f) if normalize: for key in ('train_img', 'test_img'): dataset[key] = dataset[key].astype(np.float32) dataset[key] /= 255.0 if one_hot_label: dataset['train_label'] = _change_one_hot_label(dataset['train_label']) dataset['test_label'] = _change_one_hot_label(dataset['test_label']) if not flatten: for key in ('train_img', 'test_img'): dataset[key] = dataset[key].reshape(-1, 1, 28, 28) return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if name == 'main': init_mnist()模仿这段代码将获取同样形式
你提供的是MNIST数据集的加载代码,而不是EMNIST。如果要加载EMNIST数据集,可以使用以下代码:
```python
import idx2numpy
import numpy as np
def load_emnist():
X_train = idx2numpy.convert_from_file('emnist/emnist-letters-train-images-idx3-ubyte')
T_train = idx2numpy.convert_from_file('emnist/emnist-letters-train-labels-idx1-ubyte')
X_test = idx2numpy.convert_from_file('emnist/emnist-letters-test-images-idx3-ubyte')
T_test = idx2numpy.convert_from_file('emnist/emnist-letters-test-labels-idx1-ubyte')
# 将数据展开为一维数组
X_train = X_train.reshape(X_train.shape[0], -1)
X_test = X_test.reshape(X_test.shape[0], -1)
# 将标签转换为one-hot编码
T_train = np.eye(26)[T_train]
T_test = np.eye(26)[T_test]
return (X_train, T_train), (X_test, T_test)
```
这个代码将返回一个tuple,其中第一个元素是训练数据,第二个元素是测试数据,每个数据都是一个tuple,其中第一个元素是图像数据,第二个元素是标签数据。图像数据是展开为一维数组的,标签数据是one-hot编码的。
阅读全文