tensorflow本地数据集读取
时间: 2023-09-16 16:11:24 浏览: 45
在 TensorFlow 中,可以使用 `tf.data` API 读取本地数据集。下面是一个例子:
```python
import tensorflow as tf
# 定义数据集路径
data_dir = "/path/to/data"
# 创建数据集对象
dataset = tf.data.TFRecordDataset(filenames=[data_dir])
# 解析数据集并进行预处理
def parse_fn(example_proto):
features = {"image": tf.io.FixedLenFeature([], tf.string),
"label": tf.io.FixedLenFeature([], tf.int64)}
parsed_features = tf.io.parse_single_example(example_proto, features)
image = tf.io.decode_jpeg(parsed_features['image'], channels=3)
image = tf.image.resize(image, [224, 224])
label = parsed_features['label']
return image, label
# 对数据集进行处理,包括乱序、批处理、预处理等
batch_size = 32
dataset = dataset.map(parse_fn)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
# 构建迭代器并获取数据
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
images, labels = iterator.get_next()
```
上述代码中,首先定义了数据集路径 `data_dir`,然后使用 `TFRecordDataset` 创建数据集对象。接着定义了一个解析函数 `parse_fn`,用于解析 TFRecord 数据并进行预处理。最后,对数据集进行乱序、批处理、预处理等操作,并使用迭代器获取数据。
需要注意的是,上述代码中使用了 TensorFlow 2.0 的 API,如果使用 TensorFlow 1.x 版本,需要对一些 API 进行修改。