基于tensorflow使用lenet-5实现对mnist手写数字识别分类
时间: 2023-07-29 07:08:36 浏览: 144
好的,以下是基于 TensorFlow 使用 LeNet-5 实现对 MNIST 手写数字识别分类的步骤:
1.导入必要的库:
```
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
```
2.定义 LeNet-5 模型:
```
def LeNet5(x):
# 定义第一层卷积层,使用 5x5 的卷积核,32 个卷积核,激活函数使用 relu
conv1 = tf.layers.conv2d(x, filters=32, kernel_size=[5, 5], activation=tf.nn.relu, padding='same')
# 定义第一层池化层,使用 2x2 的池化核,步长为 2
pool1 = tf.layers.max_pooling2d(conv1, pool_size=[2, 2], strides=2)
# 定义第二层卷积层,使用 5x5 的卷积核,64 个卷积核,激活函数使用 relu
conv2 = tf.layers.conv2d(pool1, filters=64, kernel_size=[5, 5], activation=tf.nn.relu, padding='same')
# 定义第二层池化层,使用 2x2 的池化核,步长为 2
pool2 = tf.layers.max_pooling2d(conv2, pool_size=[2, 2], strides=2)
# 将池化层的输出展开成一维向量
flatten = tf.layers.flatten(pool2)
# 定义第一层全连接层,输出维度为 512,激活函数使用 relu
fc1 = tf.layers.dense(flatten, units=512, activation=tf.nn.relu)
# 定义第二层全连接层,输出维度为 10,即分类数目
logits = tf.layers.dense(fc1, units=10)
return logits
```
3.定义输入输出:
```
# 定义输入输出
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, 10])
```
4.读取数据:
```
# 读取 MNIST 数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
```
5.定义模型超参数:
```
# 定义模型超参数
learning_rate = 0.001
batch_size = 128
num_epochs = 10
```
6.定义损失函数和优化器:
```
# 定义损失函数和优化器
logits = LeNet5(x)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
```
7.定义准确率:
```
# 定义准确率
correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
```
8.训练模型:
```
# 训练模型
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(num_epochs):
num_batches = mnist.train.num_examples // batch_size
for i in range(num_batches):
batch_x, batch_y = mnist.train.next_batch(batch_size)
batch_x = batch_x.reshape(-1, 28, 28, 1)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
# 计算训练集准确率和损失
train_acc, train_loss = sess.run([accuracy, loss], feed_dict={x: mnist.train.images.reshape(-1, 28, 28, 1), y: mnist.train.labels})
# 计算测试集准确率
test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images.reshape(-1, 28, 28, 1), y: mnist.test.labels})
print("Epoch:", epoch + 1, "Train Accuracy:", train_acc, "Test Accuracy:", test_acc, "Train Loss:", train_loss)
```
9.测试模型:
```
# 测试模型
test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images.reshape(-1, 28, 28, 1), y: mnist.test.labels})
print("Test Accuracy:", test_acc)
```
希望这个例子能够帮助你了解如何使用 TensorFlow 实现 LeNet-5 模型对 MNIST 手写数字进行分类。
阅读全文