帮我写一个基于minist数据集的神经网络模型训练代码
时间: 2023-02-27 12:41:40 浏览: 154
MINIST数据集识别代码(基于CNN与FullyCNN模型)
5星 · 资源好评率100%
好的,以下是一个使用 TensorFlow 和 Keras 框架实现的简单神经网络模型代码,用于训练 MNIST 数据集:
```
import tensorflow as tf
from tensorflow import keras
# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# 对图像数据进行预处理
x_train = x_train.reshape(x_train.shape[0], 28 * 28) / 255.0
x_test = x_test.reshape(x_test.shape[0], 28 * 28) / 255.0
# 建立模型
model = keras.Sequential([
keras.layers.Dense(128, activation='relu', input_shape=(28 * 28,)),
keras.layers.Dense(10, activation='softmax'),
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=5)
# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test Accuracy: ', test_acc)
```
这是一个非常简单的神经网络模型,它可以对 MNIST 数据集进行分类。你可以通过修改代码来定制它以适应更复杂的场景。
阅读全文