def load_data(): root_dir = 'data' files = os.listdir(root_dir) data = [] for file in files: if file.endswith('.npy'): path = os.path.join(root_dir, file) data.append(np.load(path)) data = np.concatenate(data, axis=0) return data
时间: 2024-04-28 21:24:13 浏览: 152
这段代码定义了一个 load_data 函数,函数实现的功能是从 'data' 目录下读取所有以 '.npy' 结尾的文件,并将它们读入内存中,最后使用 numpy.concatenate 方法将这些数据拼接成一个大的数据集,最终返回这个数据集。
具体实现过程如下:
1. 定义变量 root_dir 并将其赋值为 'data',表示数据存放的目录路径。
2. 使用 os.listdir 方法获取目录下的所有文件列表,并将其存储在变量 files 中。
3. 定义一个空列表 data,用于存储所有读取的数据。
4. 遍历 files 列表中的每个文件,如果文件以 '.npy' 结尾,则读取该文件,并将其添加到 data 列表中。
5. 使用 numpy.concatenate 方法将 data 列表中的所有数据拼接成一个大的数据集。
6. 返回拼接后的数据集。
需要注意的是,这段代码使用了 numpy 库来处理数据,因此在使用前需要先导入 numpy 库。另外,在使用 os.path.join 方法拼接文件路径时,需要保证 root_dir 和文件名之间使用斜杠进行分隔。
阅读全文