用tensorflow框架,写一个LeNet分类网络
时间: 2023-03-25 10:02:31 浏览: 109
基于tensorflow的图像处理项目:keras的API编写LeNet5网络来做mnist分类.zip
你好,我可以回答你的问题。以下是使用 TensorFlow 框架编写 LeNet 分类网络的代码示例:
```python
import tensorflow as tf
# 定义 LeNet 网络结构
def LeNet(x):
# 第一层卷积层
conv1 = tf.layers.conv2d(inputs=x, filters=6, kernel_size=[5, 5], padding="valid", activation=tf.nn.relu)
# 第一层池化层
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
# 第二层卷积层
conv2 = tf.layers.conv2d(inputs=pool1, filters=16, kernel_size=[5, 5], padding="valid", activation=tf.nn.relu)
# 第二层池化层
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
# 将卷积层输出的特征图转换为一维向量
flatten = tf.layers.flatten(pool2)
# 第一层全连接层
fc1 = tf.layers.dense(inputs=flatten, units=120, activation=tf.nn.relu)
# 第二层全连接层
fc2 = tf.layers.dense(inputs=fc1, units=84, activation=tf.nn.relu)
# 输出层
logits = tf.layers.dense(inputs=fc2, units=10)
return logits
# 定义输入占位符
x = tf.placeholder(tf.float32, [None, 32, 32, 1])
y = tf.placeholder(tf.int64, [None])
# 加载 MNIST 数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
# 将数据集转换为 LeNet 所需的格式
x_train = mnist.train.images.reshape(-1, 28, 28, 1)
y_train = mnist.train.labels
x_test = mnist.test.images.reshape(-1, 28, 28, 1)
y_test = mnist.test.labels
# 定义损失函数和优化器
logits = LeNet(x)
loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)
train_op = tf.train.AdamOptimizer().minimize(loss)
# 定义评估指标
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, axis=1), y), tf.float32))
# 训练模型
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(100):
_, loss_val, acc_val = sess.run([train_op, loss, accuracy], feed_dict={x: x_train, y: y_train})
if i % 100 == :
print("Step %d, loss = %.4f, accuracy = %.2f%%" % (i, loss_val, acc_val * 100))
test_acc = sess.run(accuracy, feed_dict={x: x_test, y: y_test})
print("Test accuracy = %.2f%%" % (test_acc * 100))
```
希望这个代码示例能够帮助到你。如果你有其他问题,欢迎继续提问。
阅读全文