TensorFlow图片标准化:tf.image.per_image_standardization详解与示例

1 下载量 49 浏览量 更新于2024-08-31 收藏 170KB PDF 举报
"tensorflow下的图片标准化函数per_image_standardization用法" 在深度学习中,数据预处理是一个非常重要的步骤,因为它可以极大地影响模型的训练效率和最终的性能。TensorFlow 提供了一个名为 `tf.image.per_image_standardization` 的函数,用于对输入的图像进行标准化处理,以确保所有特征在同一尺度上,这对于神经网络的训练尤其有利。 函数 `tf.image.per_image_standardization(image)` 是一个针对单张图像的标准化操作。它通过减去图像每个通道的均值然后除以调整后的标准差来实现。这个过程可以表达为 `(x - mean) / adjusted_stddev`,其中 `x` 是图像的 RGB 三通道像素值,`mean` 分别代表三个通道的像素均值,而 `adjusted_stddev` 是计算出的每个通道的标准差,如果标准差为零,则使用 `1.0/sqrt(image.NumElements())` 来防止除以零的情况,`image.NumElements()` 计算的是图像像素总数。 标准化的好处在于它可以使不同特征的影响保持一致,这对于使用梯度下降法优化权重时尤其重要,因为这可以避免某些特征由于数值范围过大或过小而占据主导地位。此外,标准化还可以加速训练过程,因为网络不需要花费过多时间来适应数据的不同尺度。 以下是一段使用该函数的 Python 代码示例: ```python import tensorflow as tf import matplotlib.image as img import matplotlib.pyplot as plt import numpy as np sess = tf.InteractiveSession() image = img.imread('D:/Documents/Pictures/logo7.jpg') # 获取图像形状 shape = tf.shape(image).eval() h, w = shape[0], shape[1] # 对图像进行标准化处理 standardization_image = tf.image.per_image_standardization(image) # 绘制原图和标准化后的直方图 fig = plt.figure() fig1 = plt.figure() ax = fig.add_subplot(111) ax.set_title('原始图像') ax.imshow(image) ax1 = fig1.add_subplot(311) ax1.set_title('原始直方图') ax1.hist(sess.run(tf.reshape(image, [h * w, -1]))) ax1 = fig1.add_subplot(313) ax1.set_title('标准化直方图') ax1.hist(sess.run(tf.reshape(standardization_image, [h * w, -1]))) ``` 这段代码首先读取图像,然后应用 `tf.image.per_image_standardization` 函数进行标准化,最后绘制原始图像及其直方图,以及标准化后的直方图,以便直观地比较标准化前后的差异。 `tf.image.per_image_standardization` 是 TensorFlow 中处理图像数据的一个实用工具,它可以有效地为深度学习模型提供经过预处理的数据,从而提高模型的训练效率和准确性。在实际应用中,结合其他预处理技术,如数据增强,可以进一步提升模型的泛化能力。

以下代码有什么错误,怎么修改: import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from PIL import Image import matplotlib.pyplot as plt import input_data import model import numpy as np import xlsxwriter num_threads = 4 def evaluate_one_image(): workbook = xlsxwriter.Workbook('formatting.xlsx') worksheet = workbook.add_worksheet('My Worksheet') with tf.Graph().as_default(): BATCH_SIZE = 1 N_CLASSES = 4 image = tf.cast(image_array, tf.float32) image = tf.image.per_image_standardization(image) image = tf.reshape(image, [1, 208, 208, 3]) logit = model.cnn_inference(image, BATCH_SIZE, N_CLASSES) logit = tf.nn.softmax(logit) x = tf.placeholder(tf.float32, shape=[208, 208, 3]) logs_train_dir = 'log/' saver = tf.train.Saver() with tf.Session() as sess: print("从指定路径中加载模型...") ckpt = tf.train.get_checkpoint_state(logs_train_dir) if ckpt and ckpt.model_checkpoint_path: global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] saver.restore(sess, ckpt.model_checkpoint_path) print('模型加载成功, 训练的步数为: %s' % global_step) else: print('模型加载失败,checkpoint文件没找到!') prediction = sess.run(logit, feed_dict={x: image_array}) max_index = np.argmax(prediction) workbook.close() def evaluate_images(test_img): coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for index,img in enumerate(test_img): image = Image.open(img) image = image.resize([208, 208]) image_array = np.array(image) tf.compat.v1.threading.Thread(target=evaluate_one_image, args=(image_array, index)).start() coord.request_stop() coord.join(threads) if __name__ == '__main__': test_dir = 'data/test/' import glob import xlwt test_img = glob.glob(test_dir + '*.jpg') evaluate_images(test_img)

2023-07-08 上传