基于 TensorFlow 2.5，用随机池化（stochastic pooling）替换 CBAM 模块中所有的 max pooling 操作，构成一个可以插入任意卷积神经网络的新 CBAM 模块；该模块默认输入 inputs 为 224x224x3，并示例如何使用。
时间: 2024-06-03 17:08:32 浏览: 48
import tensorflow as tf
# Stochastic pooling (Zeiler & Fergus, 2013), used below in place of every
# max pooling inside CBAM.
def _shape_list(tensor):
    """Return the shape of `tensor` as a list, preferring static ints.

    Dimensions unknown at trace time (e.g. the batch size inside a Keras
    functional model) are returned as scalar int32 tensors instead, so the
    list can be passed to `tf.reshape` without discarding the static
    dimensions that downstream layers such as `Dense` require.
    """
    dynamic = tf.shape(tensor)
    return [dim if dim is not None else dynamic[i]
            for i, dim in enumerate(tensor.shape.as_list())]


def random_pooling(inputs, axis=(1, 2), keepdims=True, training=False):
    """Stochastically pool `inputs` over `axis` (a drop-in for max pooling).

    Activations are first clipped at zero (ReLU), since stochastic pooling is
    defined on non-negative activations. Within each pooling region an
    element's probability is its activation divided by the region's total.

    * training truthy: one element per region is sampled from that
      multinomial distribution and its value returned; gradients flow only
      to the sampled element.
    * training False/None: the probability-weighted average sum(p_i * a_i)
      is returned (the deterministic test-time estimate).

    Regions whose activations are all zero pool to 0.

    Args:
        inputs: tensor of known rank.
        axis: int or sequence of ints forming one pooling region. The default
            (1, 2) pools globally over the spatial axes of an NHWC tensor.
        keepdims: if True, pooled axes are kept with length 1.
        training: Python bool; None is treated as False.

    Returns:
        The pooled tensor with `axis` removed (or kept as size-1 axes).

    Raises:
        ValueError: if the rank of `inputs` is unknown or an axis is out of
            range.
    """
    rank = inputs.shape.rank
    if rank is None:
        raise ValueError("random_pooling requires an input of known rank")
    axes = [axis] if isinstance(axis, int) else list(axis)
    for a in axes:
        if not -rank <= a < rank:
            raise ValueError(f"axis {a} is out of range for a rank-{rank} tensor")
    reduced = sorted({a % rank for a in axes})
    kept = [a for a in range(rank) if a not in reduced]

    # Move the pooled axes to the end and flatten them: one row per region.
    x = tf.transpose(tf.nn.relu(inputs), kept + reduced)
    dims = _shape_list(x)
    region_size = 1
    for dim in dims[len(kept):]:
        region_size = region_size * dim
    flat = tf.reshape(x, [-1, region_size])

    total = tf.reduce_sum(flat, axis=-1, keepdims=True)
    has_mass = total > 0
    # Avoid 0/0 in all-zero regions; their probabilities simply stay 0.
    probs = flat / tf.where(has_mass, total, tf.ones_like(total))

    if training:
        # log(0) = -inf gives zero activations zero probability; an all-zero
        # region gets uniform logits (any pick yields 0 anyway).
        logits = tf.stop_gradient(
            tf.where(has_mass, tf.math.log(probs), tf.zeros_like(probs)))
        index = tf.random.categorical(logits, num_samples=1)
        pooled = tf.gather(flat, index, batch_dims=1)[:, 0]
    else:
        pooled = tf.reduce_sum(probs * flat, axis=-1)

    pooled = tf.reshape(pooled, dims[:len(kept)])
    if keepdims:
        # `kept` is ascending, so re-inserting the pooled axes in ascending
        # order puts every axis back at its original position.
        for a in reduced:
            pooled = tf.expand_dims(pooled, a)
    return pooled


class CBAM(tf.keras.layers.Layer):
    """Convolutional Block Attention Module with stochastic pooling.

    Applies channel attention, then spatial attention. Each attention branch
    combines an average-pooled descriptor with a second pooled descriptor;
    here every max pooling is replaced by `random_pooling`. As a result, the
    layer samples while training and uses the probability-weighted average
    at inference.

    The output has the same shape as the input, so the layer can follow any
    NHWC feature map (e.g. a 224x224x3 input or a deeper conv output).

    Args:
        reduction_ratio: channel reduction of the shared MLP's hidden layer.
            The hidden width is clamped to at least 1, so inputs with few
            channels (e.g. RGB images) still work.
        kernel_size: kernel size of the spatial-attention convolution.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, reduction_ratio=16, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
        self.kernel_size = kernel_size

    def build(self, input_shape):
        channels = input_shape[-1]
        if channels is None:
            raise ValueError("CBAM requires a known channel dimension")
        channels = int(channels)
        # Channel attention: one MLP shared by both pooled descriptors.
        self.avg_pool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc1 = tf.keras.layers.Dense(
            units=max(channels // self.reduction_ratio, 1), activation='relu')
        # Linear: the sigmoid is applied after summing both MLP outputs.
        self.fc2 = tf.keras.layers.Dense(units=channels)
        # Spatial attention: conv over the stacked [average, stochastic] maps.
        self.conv = tf.keras.layers.Conv2D(
            filters=1, kernel_size=self.kernel_size, strides=1,
            padding='same', activation='sigmoid')
        super().build(input_shape)

    def call(self, inputs, training=None):
        # Channel attention: sigmoid(MLP(avg) + MLP(stochastic)) per channel.
        avg_desc = self.avg_pool(inputs)
        sto_desc = random_pooling(inputs, axis=(1, 2), keepdims=False,
                                  training=training)
        channel_weights = tf.sigmoid(
            self.fc2(self.fc1(avg_desc)) + self.fc2(self.fc1(sto_desc)))
        x = inputs * channel_weights[:, tf.newaxis, tf.newaxis, :]

        # Spatial attention on the channel-refined features; the single-channel
        # weight map broadcasts over all channels.
        avg_map = tf.reduce_mean(x, axis=-1, keepdims=True)
        sto_map = random_pooling(x, axis=-1, keepdims=True, training=training)
        spatial_weights = self.conv(tf.concat([avg_map, sto_map], axis=-1))
        return x * spatial_weights

    def get_config(self):
        config = super().get_config()
        config.update({'reduction_ratio': self.reduction_ratio,
                       'kernel_size': self.kernel_size})
        return config
# Smoke test: wrap a bare 224x224 RGB input with one CBAM layer (default
# reduction ratio) and print the layer summary of that one-layer model.
inputs = tf.keras.Input((224, 224, 3))
x = CBAM(reduction_ratio=16)(inputs)
model = tf.keras.Model(inputs, x)
model.summary()
# Example: insert a CBAM layer after every conv stage of a small CNN that
# takes 224x224x3 inputs.
inputs = tf.keras.Input(shape=(224, 224, 3))
x = inputs
for filters in (64, 128, 256, 512, 1024):
    # ReLU keeps the features non-negative, which is what stochastic pooling
    # assumes. Stride 2 halves the resolution per stage (224 -> 112 -> 56 ->
    # 28 -> 14 -> 7), so the wide late stages do not run at full 224x224
    # resolution.
    x = tf.keras.layers.Conv2D(filters=filters, kernel_size=3, strides=2,
                               padding='same', activation='relu')(x)
    x = CBAM()(x)
model = tf.keras.Model(inputs=inputs, outputs=x)
model.summary()