TensorFlow图片标准化per_image_standardization详解与应用

2 下载量 157 浏览量 更新于2024-09-01 收藏 169KB PDF 举报
在TensorFlow中,图片标准化是一种预处理技术,它有助于统一不同特征的尺度,使它们在训练神经网络时具有相等的重要性。per_image_standardization是TensorFlow提供的一个函数,专门用于对每一幅图片进行独立的标准化处理,而不是全局归一化。这种处理方式有助于加快模型的收敛速度。 该函数的用法如下: 1. 环境配置:本文档基于Windows 7系统、Anaconda3(Python 3.5版本)以及TensorFlow(支持GPU或CPU)环境。 2. 函数原理:`tf.image.per_image_standardization(image)`函数的核心操作是对输入图片的每个像素进行标准化,即`(x - mean) / adjusted_stddev`。这里,`x`代表RGB三通道的像素值,`mean`是三个通道的像素均值,而`adjusted_stddev`则是每个通道标准差的一个调整值,计算公式为`max(stddev, 1.0 / sqrt(image.NumElements()))`。`image.NumElements()`用于计算单通道像素的数量,确保标准化过程中不会因为极小的方差导致除数接近于零。 3. 实验示例: - 首先,导入所需的库:`tensorflow`、`matplotlib.image`、`matplotlib.pyplot`和`numpy`。 - 创建一个交互式会话:`sess=tf.InteractiveSession()` - 读取图片(例如logo7.jpg)并获取其形状:`image = img.imread('D:/Documents/Pictures/logo7.jpg')`,然后计算图片的高和宽。 - 使用`tf.image.per_image_standardization`对图片进行标准化处理:`standardization_image = tf.image.per_image_standardization(image)` - 使用Matplotlib展示原始图像和标准化后的图像,以及分别对每个通道的直方图进行可视化。 通过这个函数,你可以优化图像数据在神经网络中的表示,使其更有利于梯度下降等优化算法的学习。理解并掌握`per_image_standardization`的用法,对于提高深度学习模型在图像处理任务中的性能至关重要。在实际应用中,记得根据具体需求调整图片处理策略,并结合其他预处理技术(如归一化、裁剪等)以获得最佳效果。

import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from PIL import Image import matplotlib.pyplot as plt import input_data import model import numpy as np import xlsxwriter #设置线程数 num_threads = 4 def evaluate_one_image(): workbook = xlsxwriter.Workbook('formatting.xlsx') worksheet = workbook.add_worksheet('My Worksheet') with tf.Graph().as_default(): BATCH_SIZE = 1 N_CLASSES = 4 image = tf.cast(image_array, tf.float32) image = tf.image.per_image_standardization(image) image = tf.reshape(image, [1, 208, 208, 3]) logit = model.cnn_inference(image, BATCH_SIZE, N_CLASSES) logit = tf.nn.softmax(logit) x = tf.placeholder(tf.float32, shape=[208, 208, 3]) logs_train_dir = 'log/' saver = tf.train.Saver() with tf.Session() as sess: print("从指定路径中加载模型...") ckpt = tf.train.get_checkpoint_state(logs_train_dir) if ckpt and ckpt.model_checkpoint_path: global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] saver.restore(sess, ckpt.model_checkpoint_path) print('模型加载成功, 训练的步数为: %s' % global_step) else: print('模型加载失败,checkpoint文件没找到!') prediction = sess.run(logit, feed_dict={x: image_array}) max_index = np.argmax(prediction) workbook.close() def evaluate_images(test_img): coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for index,img in enumerate(test_img): image = Image.open(img) image = image.resize([208, 208]) image_array = np.array(image) tf.compat.v1.threading.Thread(target=evaluate_one_image, args=(image_array, index)).start() # 请求停止所有线程 coord.request_stop() # 等待所有线程完成 coord.join(threads) if __name__ == '__main__': # 调用方法,开始测试 test_dir = 'data/test/' import glob import xlwt test_img = glob.glob(test_dir + '*.jpg') evaluate_images(test_img)

2023-07-08 上传