```python
class BasicBlock(layers.Layer):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        # 1. The block has two convolutions in total; this is the first conv layer.
        self.conv1 = regularized_padded_conv(out_channels, kernel_size=3, strides=stride)
        self.bn1 = layers.BatchNormalization()
        # 2. The second conv. The first conv downsamples when it uses a stride;
        #    this one always keeps the spatial size unchanged, so stride is fixed at 1.
        self.conv2 = regularized_padded_conv(out_channels, kernel_size=3, strides=1)
        self.bn2 = layers.BatchNormalization()
        ############################### Attention ###############################
        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()
        # 3. Check the stride: if it is 1, there is no downsampling.
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = Sequential([regularized_padded_conv(self.expansion * out_channels,
                                                                kernel_size=1, strides=stride),
                                        layers.BatchNormalization()])
        else:
            self.shortcut = lambda x, _: x

    def call(self, inputs, training=False):
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = tf.nn.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        ############################### Attention ###############################
        out = self.ca(out) * out
        out = self.sa(out) * out
        out = out + self.shortcut(inputs, training)
        out = tf.nn.relu(out)
        return out
```
Convert this to a TensorFlow version.
Note that this code is already written in TensorFlow; it only needs to be used through the TensorFlow API. Here is the rewritten code:
```python
import tensorflow as tf

class BasicBlock(tf.keras.layers.Layer):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        # 1. The block has two convolutions in total; this first 3x3 conv
        #    performs the downsampling when stride > 1.
        self.conv1 = tf.keras.layers.Conv2D(out_channels, kernel_size=3, strides=stride, padding='same',
                                            kernel_regularizer=tf.keras.regularizers.l2(1e-4))
        self.bn1 = tf.keras.layers.BatchNormalization()
        # 2. The second 3x3 conv keeps the spatial size unchanged, so its stride is fixed at 1.
        self.conv2 = tf.keras.layers.Conv2D(out_channels, kernel_size=3, strides=1, padding='same',
                                            kernel_regularizer=tf.keras.regularizers.l2(1e-4))
        self.bn2 = tf.keras.layers.BatchNormalization()
        ############################### Attention ###############################
        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()
        # 3. If the stride is not 1 (downsampling) or the channel count changes,
        #    project the identity with a 1x1 conv so shapes match for the residual add.
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = tf.keras.Sequential([
                tf.keras.layers.Conv2D(self.expansion * out_channels, kernel_size=1, strides=stride,
                                       padding='same',
                                       kernel_regularizer=tf.keras.regularizers.l2(1e-4)),
                tf.keras.layers.BatchNormalization()
            ])
        else:
            # Identity shortcut; accepts the same training flag as the Sequential branch.
            self.shortcut = lambda x, training=False: x

    def call(self, inputs, training=False):
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = tf.nn.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        ############################### Attention ###############################
        out = self.ca(out) * out   # channel attention: (B, 1, 1, C) scale
        out = self.sa(out) * out   # spatial attention: (B, H, W, 1) scale
        out = out + self.shortcut(inputs, training=training)
        out = tf.nn.relu(out)
        return out
```
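The block calls `ChannelAttention` and `SpatialAttention`, which the question does not define. For completeness, a minimal CBAM-style sketch of those two layers might look like the following; the `reduction` ratio of 16 and the 7x7 kernel size are common defaults, not values taken from the original code:
```python
import tensorflow as tf
from tensorflow.keras import layers

class ChannelAttention(layers.Layer):
    """CBAM-style channel attention: shared MLP over avg- and max-pooled features."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = layers.GlobalAveragePooling2D()
        self.max_pool = layers.GlobalMaxPooling2D()
        # Shared two-layer MLP with a bottleneck of channels // reduction units.
        self.fc1 = layers.Dense(channels // reduction, activation='relu')
        self.fc2 = layers.Dense(channels)

    def call(self, inputs):
        avg = self.fc2(self.fc1(self.avg_pool(inputs)))
        mx = self.fc2(self.fc1(self.max_pool(inputs)))
        scale = tf.sigmoid(avg + mx)                            # (B, C)
        return tf.reshape(scale, (-1, 1, 1, scale.shape[-1]))  # broadcast over H, W

class SpatialAttention(layers.Layer):
    """CBAM-style spatial attention: conv over channel-wise avg and max maps."""
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = layers.Conv2D(1, kernel_size, padding='same', activation='sigmoid')

    def call(self, inputs):
        avg = tf.reduce_mean(inputs, axis=-1, keepdims=True)  # (B, H, W, 1)
        mx = tf.reduce_max(inputs, axis=-1, keepdims=True)    # (B, H, W, 1)
        return self.conv(tf.concat([avg, mx], axis=-1))       # (B, H, W, 1)
```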
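As a quick smoke test, assuming the layers above are defined, you can run a dummy NHWC batch through the block and check the output shape:
```python
block = BasicBlock(in_channels=64, out_channels=128, stride=2)
x = tf.random.normal((1, 32, 32, 64))   # dummy NHWC input
y = block(x, training=False)
print(y.shape)  # (1, 16, 16, 128): spatial size halved, channels widened
```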