Write Python code for a residual network
Sure, here is a simple Python example of a residual network implementation:
```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, Dense

def residual_block(x, filters, strides=(1, 1)):
    # Residual block: two 3x3 convolutions plus a shortcut connection
    shortcut = x
    # First convolution, followed by BN and ReLU
    x = Conv2D(filters, (3, 3), strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    # Second convolution, followed by BN
    x = Conv2D(filters, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    # If the spatial size or channel count changes, project the shortcut
    # with a 1x1 convolution so the two branches can be added
    if strides != (1, 1) or shortcut.shape[-1] != filters:
        shortcut = Conv2D(filters, (1, 1), strides=strides, padding='same')(shortcut)
    # Add the main path and the shortcut, then apply ReLU
    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x

# Build a 34-layer residual network
inputs = Input(shape=(224, 224, 3))
x = Conv2D(64, (7, 7), strides=(2, 2), padding='same')(inputs)
x = BatchNormalization()(x)
x = ReLU()(x)
x = residual_block(x, filters=64)
x = residual_block(x, filters=64)
x = residual_block(x, filters=64, strides=(2, 2))
x = residual_block(x, filters=128)
x = residual_block(x, filters=128)
x = residual_block(x, filters=128)
x = residual_block(x, filters=128, strides=(2, 2))
x = residual_block(x, filters=256)
x = residual_block(x, filters=256)
x = residual_block(x, filters=256)
x = residual_block(x, filters=256)
x = residual_block(x, filters=256)
x = residual_block(x, filters=256, strides=(2, 2))
x = residual_block(x, filters=512)
x = residual_block(x, filters=512)
x = residual_block(x, filters=512)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = Dense(1, activation='sigmoid')(x)
model = tf.keras.models.Model(inputs, outputs)
model.summary()
```
This code builds a 34-layer residual network. It takes 224x224 input images and produces a binary classification prediction. Each residual block contains two 3x3 convolutional layers and a shortcut connection; when the input and output shapes differ (in spatial size or channel count), the shortcut is passed through a 1x1 convolution so the dimensions match. BN and ReLU are applied around the convolutions inside each block. The final Dense layer uses a sigmoid activation, so the output represents the predicted probability.
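As a quick sanity check, the model defined above can be compiled and run on a random batch. The optimizer, loss, and dummy data below are illustrative assumptions, not part of the original answer; they only verify that shapes flow through the network correctly:

```python
import numpy as np

# Hypothetical usage example: compile for binary classification
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Random images and binary labels, just to exercise the graph
dummy_images = np.random.rand(4, 224, 224, 3).astype('float32')
dummy_labels = np.random.randint(0, 2, size=(4, 1)).astype('float32')

model.fit(dummy_images, dummy_labels, epochs=1, batch_size=2)
print(model.predict(dummy_images).shape)  # expected: (4, 1)
```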