加载MNIST数据集,查看数据的基本情况。 构架神经网络模型,识别手写数字,并评估模型效果。写出完整的python代码
时间: 2024-12-17 22:23:52 浏览: 14
在Python中,我们可以使用`tensorflow`库加载MNIST数据集,这是一个经典的数据集,包含60000张训练图片和10000张测试图片,用于识别手写数字。以下是加载数据并构建简单神经网络模型的基本步骤:
首先,安装必要的库(如果尚未安装):
```bash
pip install tensorflow numpy mnist
```
然后,导入所需的模块并加载数据:
```python
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 归一化像素值到0-1之间
x_train, x_test = x_train / 255.0, x_test / 255.0
# 将标签转换为one-hot编码
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
# 数据预处理:调整形状并增加维度
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)
```
接下来,我们构建一个简单的卷积神经网络(CNN)模型:
```python
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
```
然后,编译模型、指定损失函数、优化器和评价指标:
```python
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
```
训练模型:
```python
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)
```
最后,评估模型性能:
```python
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")
```
完整代码如下:
```python
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
# ... (同上部分)
# 建立模型
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)
# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")
#
阅读全文