tensorflow批量读取图片文件并转换为张量
时间: 2023-06-18 12:06:50 浏览: 110
tensorflow如何批量读取图片
可以使用TensorFlow中的`tf.data.Dataset` API来批量读取和预处理图片数据。
以下是一个示例代码,假设要读取目录中所有的JPEG图像文件:
```python
import tensorflow as tf
# 定义一些常量
IMAGE_SIZE = 224
BATCH_SIZE = 32
NUM_CLASSES = 10
NUM_EPOCHS = 10
# 构建数据集
def parse_fn(filename, label):
# 读取图片数据
image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
# 调整图片大小
image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
# 数据增强
image = tf.image.random_flip_left_right(image)
image = tf.image.random_brightness(image, max_delta=0.2)
image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
# 转换为张量
image = tf.keras.applications.resnet50.preprocess_input(image)
return image, label
# 读取图片路径和标签信息
image_paths = [...] # 图片路径列表
labels = [...] # 标签列表
# 构建数据集
ds = tf.data.Dataset.from_tensor_slices((image_paths, labels))
ds = ds.map(parse_fn)
ds = ds.shuffle(buffer_size=len(image_paths))
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
# 构建模型
model = tf.keras.applications.ResNet50(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
weights=None,
classes=NUM_CLASSES)
# 编译模型
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(ds, epochs=NUM_EPOCHS)
```
在上述代码中,`parse_fn`函数用于对每个图片文件进行预处理和转换为张量。`tf.data.Dataset.from_tensor_slices`函数用于将图片路径和标签信息组成一个元组,然后通过`map`函数将每个元组转换为对应的图片张量和标签。最后,通过`shuffle`、`batch`和`prefetch`函数将数据集随机打乱、分批次处理和提前加载。
注意,上述代码中使用了`tf.keras.applications.resnet50.preprocess_input`函数对图片进行了预处理,这是为了将图片数据转换为ResNet50模型所需的格式。如果你使用的是其他的模型,可能需要使用不同的预处理方式。
阅读全文