# 数据集加载函数,指明数据集的位置并统一处理为imgheight*imgwidth的大小,同时设置batch def data_load(data_dir, test_data_dir, img_height, img_width, batch_size): # 加载训练集 train_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, label_mode='categorical', seed=123, image_size=(img_height, img_width), batch_size=batch_size) # 加载测试集 val_ds = tf.keras.preprocessing.image_dataset_from_directory( test_data_dir, label_mode='categorical', seed=123, image_size=(img_height, img_width), batch_size=batch_size) class_names = train_ds.class_names # 返回处理之后的训练集、验证集和类名 return train_ds, val_ds, class_names
时间: 2023-12-24 20:05:52 浏览: 129
数据集包括训练集和测试集
这段代码使用了 TensorFlow 的 `preprocessing.image_dataset_from_directory` 函数来加载训练集和测试集,并将它们处理为指定的 `img_height` 和 `img_width` 大小。同时,它还将标签处理为独热编码(one-hot encoding)的形式,因为 `label_mode` 参数被设置为 `'categorical'`。最后,这个函数返回处理之后的训练集、验证集和类名。
需要注意的是,这个函数假设数据集的目录结构是按照类别分组的,每个目录的名称即为对应类别的名称。例如,如果有一个猫狗分类的数据集,那么数据集的目录结构可能如下所示:
```
data/
├── cat/
│ ├── cat001.jpg
│ ├── cat002.jpg
│ └── ...
└── dog/
├── dog001.jpg
├── dog002.jpg
└── ...
```
其中,`cat/` 和 `dog/` 分别是两个类别的目录,里面包含了对应类别的图片文件。在这种情况下,`class_names` 将会是一个包含 `'cat'` 和 `'dog'` 两个字符串的列表。
阅读全文