TensorFlow 2 残差网络
时间: 2024-01-31 18:11:03 浏览: 25
以下是使用TensorFlow 2实现残差网络的示例代码:
```python
import tensorflow as tf
from tensorflow.keras import layers
# Residual block: two conv + batch-norm stages plus a skip connection.
class ResidualBlock(tf.keras.Model):
    """Basic residual block made of two 3x3 convolutions.

    The main path is conv -> batch norm -> activation -> conv -> batch norm.
    Its output is added to the shortcut path and the sum is passed through
    the activation once more.

    Args:
        filters: Number of output channels of both 3x3 convolutions.
        strides: Stride of the first convolution. When it is not 1, the
            shortcut becomes a strided 1x1 convolution followed by batch
            normalization so that both paths have matching shapes.
        activation: Activation identifier handed to ``layers.Activation``.

    Note:
        With ``strides == 1`` the input is added to the main path unchanged,
        so the input must already have ``filters`` channels.
    """

    def __init__(self, filters, strides=1, activation='relu'):
        super().__init__()
        self.conv1 = layers.Conv2D(filters, 3, strides=strides, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(filters, 3, strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        self.activation = layers.Activation(activation)
        # Create the merge layer once instead of on every forward pass.
        self.merge = layers.Add()
        if strides != 1:
            # Projection shortcut: match the down-sampled spatial size and
            # the channel count of the main path.
            self.shortcut = tf.keras.Sequential([
                layers.Conv2D(filters, 1, strides=strides),
                layers.BatchNormalization(),
            ])
        else:
            # Identity shortcut. An empty Sequential is deliberately avoided:
            # newer Keras releases refuse to build a Sequential that has no
            # layers.
            self.shortcut = None

    def call(self, inputs, training=None):
        """Run the block on ``inputs``.

        Args:
            inputs: Input tensor fed to both the main and shortcut paths.
            training: Optional flag forwarded to the batch-normalization
                layers; ``None`` lets Keras decide from the calling context.

        Returns:
            The activated sum of the main path and the shortcut.
        """
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        if self.shortcut is None:
            shortcut = inputs
        else:
            shortcut = self.shortcut(inputs, training=training)
        x = self.merge([x, shortcut])
        return self.activation(x)
# Residual network: convolutional stem, residual blocks, classifier head.
class ResNet(tf.keras.Model):
    """Small residual network for image classification.

    Layer order: 7x7 stride-2 convolution (64 filters) -> batch norm -> ReLU
    -> 3x3 stride-2 max pooling -> three ``ResidualBlock(64)`` -> global
    average pooling -> dense softmax layer.

    Args:
        num_classes: Number of units of the final dense layer, i.e. the
            number of classes to predict.

    Note:
        The final layer applies softmax, so the model outputs probabilities
        rather than logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Stem: strided convolution followed by strided max pooling.
        self.conv = layers.Conv2D(64, 7, strides=2, padding='same')
        self.bn = layers.BatchNormalization()
        self.activation = layers.Activation('relu')
        self.pool = layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')
        # Three stride-1 blocks that keep the 64-channel width of the stem.
        self.res_blocks = tf.keras.Sequential([
            ResidualBlock(64, strides=1),
            ResidualBlock(64, strides=1),
            ResidualBlock(64, strides=1),
        ])
        # Classifier head.
        self.global_avg_pool = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None):
        """Compute class probabilities for ``inputs``.

        Args:
            inputs: Batch of input images.
            training: Optional flag forwarded to batch normalization and to
                the residual blocks; ``None`` lets Keras decide from the
                calling context.

        Returns:
            Tensor of per-class probabilities with ``num_classes`` entries
            per sample.
        """
        x = self.conv(inputs)
        x = self.bn(x, training=training)
        x = self.activation(x)
        x = self.pool(x)
        x = self.res_blocks(x, training=training)
        x = self.global_avg_pool(x)
        return self.fc(x)
# Build the classifier with 10 output classes.
model = ResNet(num_classes=10)

# Configure training: default Adam optimizer, sparse categorical
# cross-entropy loss (integer class labels) and an accuracy metric.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

# Load the Fashion MNIST train / test splits.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Preprocessing: cast to float32, divide by 255 (presumably mapping 8-bit
# pixel values to [0, 1]) and add a trailing channel axis -> (N, 28, 28, 1).
x_train = (x_train.astype('float32') / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype('float32') / 255.0).reshape(-1, 28, 28, 1)

# Train for 10 epochs; the test split doubles as validation data here.
model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=10,
    validation_data=(x_test, y_test),
)

# Evaluate on the test split and report the accuracy.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print('Test accuracy:', test_acc)
```