ResNet网络模型实现MNIST手写数字识别的代码示例
时间: 2023-05-25 10:03:59 浏览: 337
以下是使用ResNet网络模型实现MNIST手写数字识别的代码示例:
```python
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 数据预处理:将像素值缩放到0到1之间,并增加一维通道数
x_train = x_train.astype('float32') / 255.0
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.astype('float32') / 255.0
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
# 定义ResNet网络模型
def residual_block(inputs, filters, strides=1):
shortcut = inputs
# 第一个卷积层
x = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(inputs)
x = BatchNormalization()(x)
x = ReLU()(x)
# 第二个卷积层
x = Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)
x = BatchNormalization()(x)
# 如果输入和输出的维度不同,则使用1x1卷积调整维度
if shortcut.shape[-1] != filters:
shortcut = Conv2D(filters, kernel_size=1, strides=strides, padding='same')(shortcut)
shortcut = BatchNormalization()(shortcut)
# 将残差块的输出与输入相加,构成下一层的输入
x = Add()([x, shortcut])
x = ReLU()(x)
return x
def ResNet(input_shape=(28, 28, 1), num_classes=10):
inputs = Input(shape=input_shape)
# 第一层卷积
x = Conv2D(64, kernel_size=3, strides=1, padding='same')(inputs)
x = BatchNormalization()(x)
x = ReLU()(x)
# 残差块组1
x = residual_block(x, filters=64, strides=1)
x = residual_block(x, filters=64, strides=1)
x = residual_block(x, filters=64, strides=1)
# 残差块组2
x = residual_block(x, filters=128, strides=2)
x = residual_block(x, filters=128, strides=1)
x = residual_block(x, filters=128, strides=1)
# 残差块组3
x = residual_block(x, filters=256, strides=2)
x = residual_block(x, filters=256, strides=1)
x = residual_block(x, filters=256, strides=1)
# 残差块组4
x = residual_block(x, filters=512, strides=2)
x = residual_block(x, filters=512, strides=1)
x = residual_block(x, filters=512, strides=1)
# 全局平均池化
x = GlobalAveragePooling2D()(x)
# 全连接层
x = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=inputs, outputs=x)
return model
# 创建ResNet模型
model = ResNet(num_classes=10)
# 编译模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_data=(x_test, y_test))
```
阅读全文