Building a ResNet Convolutional Neural Network with TensorFlow (Keras) for Handwritten Digit Recognition
Date: 2023-06-20 20:06:27
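OK, let's get started!
First, we import the required libraries and the dataset. We use the MNIST dataset, which contains 28×28 grayscale images of handwritten digits (60,000 for training and 10,000 for testing) together with their labels.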
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
# Load the dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
```
Next, we preprocess the images by scaling the pixel values to the range 0 to 1, and we convert the labels to one-hot encoding.
```python
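# Scale pixel values to the range 0 to 1
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1), matching the model's input shape
x_train = x_train[..., None]
x_test = x_test[..., None]
# Convert the labels to one-hot encoding
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)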
```
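To make the two preprocessing steps concrete, here is a minimal pure-Python sketch of what the scaling and the one-hot conversion compute for a single pixel and a single label. The helper names `scale_pixel` and `one_hot` are hypothetical and exist only for illustration; the code above does the same work on whole arrays with NumPy arithmetic and `keras.utils.to_categorical`.

```python
def scale_pixel(value):
    """Map an 8-bit pixel intensity (0..255) to a float in [0, 1]."""
    return value / 255.0


def one_hot(label, num_classes=10):
    """Return a list with 1.0 at index `label` and 0.0 everywhere else."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} is outside 0..{num_classes - 1}")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


print(scale_pixel(255))  # 1.0
print(one_hot(3))        # 1.0 at index 3, 0.0 elsewhere
```

One-hot labels are what the `categorical_crossentropy` loss expects; with integer labels you would use `sparse_categorical_crossentropy` instead and skip the conversion.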
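Then we define the network. It is a simplified ResNet-18-style structure adapted to 28×28 inputs, built from convolution layers, batch normalization layers, ReLU activations, residual blocks, and a global average pooling layer. Unlike the canonical ResNet-18, it has no 7×7 stem convolution or max-pooling layer, and it changes the resolution and channel count between stages with plain convolution blocks rather than residual blocks with projection shortcuts.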
```python
def conv_block(inputs, filters, strides):
    """Conv -> BatchNorm -> ReLU. Sets the channel count and, via strides, the spatial size."""
    x = layers.Conv2D(filters, 3, strides=strides, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    return x


def identity_block(inputs, filters):
    """Residual block with an identity shortcut. `filters` must equal the input channel count."""
    x = layers.Conv2D(filters, 3, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    # Add the unchanged input back (the skip connection), then apply ReLU
    x = layers.Add()([inputs, x])
    x = layers.ReLU()(x)
    return x


def resnet18():
    inputs = keras.Input(shape=(28, 28, 1))
    # Stage 1: 64 channels at full resolution
    x = conv_block(inputs, 64, strides=1)
    x = identity_block(x, 64)
    x = identity_block(x, 64)
    # Stage 2: 128 channels, resolution halved
    x = conv_block(x, 128, strides=2)
    x = identity_block(x, 128)
    x = identity_block(x, 128)
    # Stage 3: 256 channels, resolution halved
    x = conv_block(x, 256, strides=2)
    x = identity_block(x, 256)
    x = identity_block(x, 256)
    # Stage 4: 512 channels, resolution halved
    x = conv_block(x, 512, strides=2)
    x = identity_block(x, 512)
    x = identity_block(x, 512)
    # Classification head: average each feature map, then predict 10 classes
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(10, activation="softmax")(x)
    return keras.Model(inputs, outputs)
```
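It can help to trace how the feature-map shape changes through the four stages. The sketch below is a rough pure-Python calculation, not part of the model: it assumes the rule that a convolution with `padding="same"` produces an output of size ceil(input / stride), and the helper names `same_padding_output_size` and `trace_feature_maps` are hypothetical.

```python
import math


def same_padding_output_size(size, stride):
    """Spatial output size of a convolution with padding="same"."""
    return math.ceil(size / stride)


def trace_feature_maps(input_size=28,
                       stages=((64, 1), (128, 2), (256, 2), (512, 2))):
    """Return the (height, width, channels) shape after each stage.

    Each stage is (filters, stride) for its conv_block; the identity
    blocks that follow use stride 1 and the same number of filters,
    so they leave the shape unchanged.
    """
    shapes = []
    size = input_size
    for filters, stride in stages:
        size = same_padding_output_size(size, stride)
        shapes.append((size, size, filters))
    return shapes


for shape in trace_feature_maps():
    print(shape)
# (28, 28, 64)
# (14, 14, 128)
# (7, 7, 256)
# (4, 4, 512)
```

This also shows why `identity_block` is always called with the same `filters` value as the preceding `conv_block`: `layers.Add()` requires both of its inputs to have identical shapes, so the shortcut and the convolution path must agree on height, width, and channel count.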
Finally, we compile and train the model, using the categorical cross-entropy loss and the Adam optimizer.
```python
# Create the model
model = resnet18()
# Compile the model
model.compile(
    loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
# Train the model, holding out 10% of the training data for validation
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_split=0.1)
# Evaluate the model on the test set
model.evaluate(x_test, y_test)
```
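For intuition about the loss used above, here is a minimal pure-Python sketch of categorical cross-entropy for a single sample, computed as the negative sum of `y_true[i] * log(y_pred[i])`. The function below is an illustrative stand-in, not Keras's implementation; the `eps` clipping is an assumption added here to avoid `log(0)`, and Keras additionally averages the per-sample losses over each batch.

```python
import math


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy for one sample with a one-hot target.

    y_true: one-hot list of floats; y_pred: predicted probabilities.
    eps is a hypothetical lower bound that keeps log() finite.
    """
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))


y_true = [0.0, 0.0, 1.0]          # the correct class is index 2
confident = [0.05, 0.05, 0.90]    # most probability on the correct class
unsure = [0.30, 0.30, 0.40]       # probability spread across classes

print(categorical_crossentropy(y_true, confident))
print(categorical_crossentropy(y_true, unsure))
```

Because the target is one-hot, only the probability assigned to the correct class contributes, so the loss reduces to `-log(p_correct)`: the confident prediction yields a smaller loss than the unsure one.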
Congratulations! You have now built a ResNet-style convolutional neural network with TensorFlow (Keras) and used it for handwritten digit recognition.