def load_data(file_dir): all_num = 4000 train_num = int(all_num * 0.75) cats = [] label_cats = [] dogs = [] label_dogs = [] for file in os.listdir(file_dir): file="\\"+file name = file.split(sep='.') if 'cat' in name[0]: cats.append(file_dir + file) label_cats.append(0) else: if 'dog' in name[0]: dogs.append(file_dir + file) label_dogs.append(1) image_list = np.hstack((cats,dogs)) label_list = np.hstack((label_cats, label_dogs)) temp = np.array([image_list, label_list])
时间: 2024-02-14 19:22:43 浏览: 68
这段代码是一个用于加载数据的函数,主要功能是从指定路径 file_dir 中读取图片文件,并将它们转换为 numpy 数组。具体来说,该函数首先定义了一些变量,包括总共的图片数量 all_num,训练集的图片数量 train_num,以及用于存储图片路径和标签的空列表 cats、label_cats、dogs、label_dogs。接着,使用 os.listdir(file_dir) 函数遍历 file_dir 目录下的所有文件,并将其中的猫和狗的图片路径和标签分别存储到 cats、label_cats、dogs、label_dogs 中。这里使用了字符串操作和列表操作,通过判断文件名中是否包含 'cat' 和 'dog' 来确定图片的标签。然后,使用 np.hstack 函数将 cats 和 dogs 列表合并成一个 image_list 列表,将 label_cats 和 label_dogs 列表合并成一个 label_list 列表。最后,使用 np.array 将 image_list 和 label_list 列表转换为 numpy 数组,并返回结果。需要注意的是,该函数并没有对图片进行读取和预处理的操作,只是简单地将图片路径和标签存储到了列表中。
阅读全文