ResNet18 TensorFlow 版本代码
时间: 2023-03-26 11:02:51 浏览: 139
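以下是基于 TensorFlow(Keras)的 ResNet 实现代码。注意:函数虽然命名为 ResNet18,但网络由三层卷积的瓶颈残差块构成,四个阶段的残差块数量依次为 3、4、6、3,这与 ResNet-50 的结构一致,而不是标准的 ResNet-18(后者使用两层 3x3 卷积的基本残差块,各阶段为 2、2、2、2):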
```python
import tensorflow as tf
from tensorflow.keras.layers import (
    Activation,
    Add,
    AveragePooling2D,
    BatchNormalization,
    Conv2D,
    Dense,
    Flatten,
    Input,
    MaxPooling2D,
    ZeroPadding2D,
)
from tensorflow.keras.models import Model
def identity_block(X, f, filters):
    """Apply a three-convolution residual block with an identity shortcut.

    The main path is a 1x1 convolution, an ``f`` x ``f`` convolution with
    'same' padding, and a second 1x1 convolution. Each convolution is
    followed by batch normalization over axis 3, and the first two are
    additionally followed by ReLU. The untouched input is then added to
    the main-path output and a final ReLU is applied.

    Args:
        X: Input tensor. It is added unchanged to the main-path output,
            so the main path must produce a tensor of the same shape.
            Batch normalization runs over axis 3, so a 4-D channels-last
            tensor is presumably expected.
        f: Side length of the square kernel used by the middle convolution.
        filters: Sequence of exactly three filter counts, one for each
            convolution in main-path order.

    Returns:
        The output tensor of the block.
    """
    first_ch, middle_ch, last_ch = filters
    skip = X

    # 1x1 convolution -> BN -> ReLU
    out = Conv2D(first_ch, (1, 1), strides=(1, 1), padding='valid')(X)
    out = BatchNormalization(axis=3)(out)
    out = Activation('relu')(out)

    # f x f convolution ('same' padding keeps the spatial size) -> BN -> ReLU
    out = Conv2D(middle_ch, (f, f), strides=(1, 1), padding='same')(out)
    out = BatchNormalization(axis=3)(out)
    out = Activation('relu')(out)

    # 1x1 convolution -> BN; the ReLU is deferred until after the addition
    out = Conv2D(last_ch, (1, 1), strides=(1, 1), padding='valid')(out)
    out = BatchNormalization(axis=3)(out)

    merged = Add()([out, skip])
    return Activation('relu')(merged)
def convolutional_block(X, f, filters, s=2):
    """Apply a three-convolution residual block with a projection shortcut.

    The main path mirrors ``identity_block`` (1x1 conv, ``f`` x ``f`` conv,
    1x1 conv, each with batch normalization over axis 3 and ReLU after the
    first two), except that the first convolution uses stride ``s``. The
    shortcut is not the raw input: it goes through its own 1x1 convolution
    with stride ``s`` and the last filter count, followed by batch
    normalization, before being added to the main path.

    Args:
        X: Input tensor. Batch normalization runs over axis 3, so a 4-D
            channels-last tensor is presumably expected.
        f: Side length of the square kernel used by the middle convolution.
        filters: Sequence of exactly three filter counts, one for each
            main-path convolution; the third is also used by the shortcut
            convolution.
        s: Stride applied by the first main-path convolution and by the
            shortcut convolution. Defaults to 2.

    Returns:
        The output tensor of the block.
    """
    first_ch, middle_ch, last_ch = filters
    stride = (s, s)

    # Main path. It is built before the shortcut so layers are created in
    # the same order as before (Keras assigns default layer names in
    # creation order).
    main = Conv2D(first_ch, (1, 1), strides=stride, padding='valid')(X)
    main = BatchNormalization(axis=3)(main)
    main = Activation('relu')(main)

    main = Conv2D(middle_ch, (f, f), strides=(1, 1), padding='same')(main)
    main = BatchNormalization(axis=3)(main)
    main = Activation('relu')(main)

    main = Conv2D(last_ch, (1, 1), strides=(1, 1), padding='valid')(main)
    main = BatchNormalization(axis=3)(main)

    # Projection shortcut: same stride and output filter count as the main
    # path, so both branches can be added.
    skip = Conv2D(last_ch, (1, 1), strides=stride, padding='valid')(X)
    skip = BatchNormalization(axis=3)(skip)

    merged = Add()([main, skip])
    return Activation('relu')(merged)
def ResNet18(input_shape=(224, 224, 3), classes=100):
    """Build the residual classification network as a Keras ``Model``.

    Structure, in order:

    1. Stem: zero padding of 3, a 7x7 convolution with 64 filters and
       stride 2, batch normalization over axis 3, ReLU, then a 3x3 max
       pool with stride 2.
    2. Four stages, each made of one ``convolutional_block`` followed by
       several ``identity_block`` calls (2, 3, 5 and 2 respectively), all
       with ``f=3``.
    3. Head: a 2x2 average pool, ``Flatten``, and a softmax ``Dense``
       layer with ``classes`` units.

    NOTE(review): despite the function and model name, every block here
    has three convolutions and the stages hold 3/4/6/3 blocks, which is
    the ResNet-50 layout from the original ResNet paper, not the
    two-convolution 2/2/2/2 layout of ResNet-18. The name is kept so
    existing callers and saved models keep working.

    NOTE(review): the average pool uses a fixed 2x2 window rather than a
    global pool, so the width of the flattened feature vector depends on
    ``input_shape``.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.
            Batch normalization runs over axis 3 of the batched tensor,
            so (height, width, channels) is presumably expected.
        classes: Number of units in the final softmax layer.

    Returns:
        A ``Model`` named ``'ResNet18'`` mapping the input tensor to the
        softmax output.
    """
    X_input = Input(input_shape)

    # Stem.
    X = ZeroPadding2D((3, 3))(X_input)
    X = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2))(X)
    X = BatchNormalization(axis=3)(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((3, 3), strides=(2, 2))(X)

    # One row per stage: (filter counts, stride of the leading
    # convolutional_block, number of identity_blocks that follow it).
    # The first stage uses stride 1 because the max pool above has just
    # reduced the spatial size.
    stages = (
        ([64, 64, 256], 1, 2),
        ([128, 128, 512], 2, 3),
        ([256, 256, 1024], 2, 5),
        ([512, 512, 2048], 2, 2),
    )
    for stage_filters, stride, identity_count in stages:
        X = convolutional_block(X, f=3, filters=stage_filters, s=stride)
        for _ in range(identity_count):
            X = identity_block(X, f=3, filters=stage_filters)

    # Classification head. AveragePooling2D must be imported at file level
    # alongside the other Keras layers.
    X = AveragePooling2D((2, 2))(X)
    X = Flatten()(X)
    X = Dense(classes, activation='softmax')(X)

    return Model(inputs=X_input, outputs=X, name='ResNet18')
```
希望对你有帮助!
阅读全文