tf2 加载tfrecord 代码
时间: 2024-03-11 16:18:28 浏览: 190
TFRecord文件查看包含的所有Features代码
以下是 TensorFlow 2.x 中加载 TFRecord 的示例代码:
```python
import tensorflow as tf
# 定义解析函数
def parse_example(example_proto):
feature_description = {
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64),
}
# 解析单个样本
example = tf.io.parse_single_example(example_proto, feature_description)
# 解码图像数据
image = tf.io.decode_jpeg(example['image'], channels=3)
# 归一化图像数据
image = tf.cast(image, tf.float32) / 255.0
# 返回图像和标签
return image, example['label']
# 加载 TFRecord 文件
filenames = ['train.tfrecord']
raw_dataset = tf.data.TFRecordDataset(filenames)
# 解析 TFRecord 数据
dataset = raw_dataset.map(parse_example)
# 打印前 10 个样本
for image, label in dataset.take(10):
print('Image shape: ', image.shape)
print('Label: ', label)
```
在上述代码中,我们首先定义了一个解析函数 `parse_example`,它将 TFRecord 中的每个样本解析成图像和标签。然后我们使用 `TFRecordDataset` 加载 TFRecord 文件,并使用 `map` 方法将每个样本解析成图像和标签。最后,我们使用 `take` 方法打印前 10 个样本。
阅读全文