帮我写Mnist数据集分类的代码
时间: 2023-03-19 07:21:31 浏览: 109
Chinese MNIST数据集
好的,以下是一个基于Python的Mnist数据集分类的简单代码:
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import mnist
# 加载Mnist数据集并拆分为训练集和测试集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 将像素值缩放到0到1之间
train_images = train_images / 255.0
test_images = test_images / 255.0
# 定义模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10)
])
# 编译模型
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 训练模型
model.fit(train_images, train_labels, epochs=10)
# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)
```
这个代码使用Tensorflow框架,首先加载Mnist数据集,然后将像素值缩放到0到1之间。接着定义了一个简单的神经网络模型,包含一个输入层、一个隐藏层和一个输出层。其中输入层为一个28x28的矩阵(表示每张图片的像素值),隐藏层为128个神经元,激活函数为ReLU,输出层为10个神经元(表示0到9这10个数字),没有激活函数。接着编译模型,使用Adam优化器、交叉熵损失函数和准确率指标。最后训练模型,使用训练集训练10个周期,并评估模型在测试集上的准确率。
阅读全文