用resnet网络模型实现mnist
时间: 2023-05-25 08:03:54 浏览: 114
Resnet50卷积神经网络训练MNIST手写数字图像分类 Pytorch训练代码
5星 · 资源好评率100%
ResNet是一个非常深的深度学习网络结构,可以有效地解决梯度消失问题。在此提供通过ResNet网络模型实现MNIST手写数字识别的代码示例,如下所示:
```python
import keras
from keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dense
from keras.layers import Input, add, Activation, Flatten, Dropout
from keras.models import Model
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import SGD
# 定义ResNet网络模型
def residual_module(layer_in, n_filters):
merge_input = layer_in
# 第一层卷积
layer = Conv2D(n_filters, (3,3), padding='same', activation='relu')(layer_in)
# 第二层卷积
layer = Conv2D(n_filters, (3,3), padding='same', activation='linear')(layer)
# 合并
layer = add([layer, merge_input])
# 激活函数
layer_out = Activation('relu')(layer)
return layer_out
def build_resnet(input_shape, num_classes):
# 输入层
inputs = Input(shape=input_shape)
# 第一层卷积
conv1 = Conv2D(64, (3,3), padding='same', activation='linear')(inputs)
# 残差模块1
res1 = residual_module(conv1, 64)
res1 = residual_module(res1, 64)
res1 = residual_module(res1, 64)
# 残差模块2
res2 = residual_module(res1, 128)
res2 = residual_module(res2, 128)
res2 = residual_module(res2, 128)
# 残差模块3
res3 = residual_module(res2, 256)
res3 = residual_module(res3, 256)
res3 = residual_module(res3, 256)
# 池化层
pool = GlobalAveragePooling2D()(res3)
# 全连接层
fc1 = Dense(512, activation='relu')(pool)
fc2 = Dropout(0.5)(fc1)
output = Dense(num_classes, activation='softmax')(fc2)
# 定义模型
model = Model(inputs=inputs, outputs=output)
return model
# 加载MNIST数据集
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# 数据预处理
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32') / 255
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32') / 255
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
# 构建ResNet网络模型
model = build_resnet((28,28,1), 10)
model.summary()
# 编译模型
learning_rate = 0.1
sgd = SGD(lr=learning_rate, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
# 训练模型
epochs = 10
batch_size = 128
history = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(X_test, y_test))
# 评估模型
loss, accuracy = model.evaluate(X_test, y_test)
print("Accuracy: {:.2f}%".format(accuracy * 100))
# 绘制准确率和损失值的曲线
import matplotlib.pyplot as plt
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='validation accuracy')
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='validation loss')
plt.title('ResNet MNIST')
plt.xlabel('Epoch')
plt.ylabel('Accuracy/Loss')
plt.legend()
plt.show()
```
在进行更大的迭代次数后,该示例代码可使准确度达到大约98%。
阅读全文