基于mnist数据集构建MLP,并绘制训练次数和准确率的图像
时间: 2023-12-12 10:55:57 浏览: 84
以下是基于MNIST数据集构建MLP的Python代码,并绘制训练次数和准确率的图像:
```python
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# 加载MNIST数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 定义超参数
learning_rate = 0.1
num_epochs = 100
batch_size = 100
display_step = 1
# 定义输入和输出变量
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
# 定义权重和偏置变量
W1 = tf.Variable(tf.random_normal([784, 256]))
b1 = tf.Variable(tf.random_normal([256]))
W2 = tf.Variable(tf.random_normal([256, 10]))
b2 = tf.Variable(tf.random_normal([10]))
# 定义MLP模型
hidden_layer = tf.nn.relu(tf.add(tf.matmul(x, W1), b1))
output_layer = tf.add(tf.matmul(hidden_layer, W2), b2)
# 定义损失函数和优化器
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output_layer, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
# 定义准确率计算方式
correct_pred = tf.equal(tf.argmax(output_layer, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# 初始化变量
init = tf.global_variables_initializer()
# 开始训练MLP模型
with tf.Session() as sess:
sess.run(init)
num_batches = int(mnist.train.num_examples / batch_size)
training_costs = []
training_accs = []
for epoch in range(num_epochs):
for i in range(num_batches):
batch_x, batch_y = mnist.train.next_batch(batch_size)
_, c, acc = sess.run([optimizer, cost, accuracy], feed_dict={x: batch_x, y: batch_y})
if epoch % display_step == 0:
training_costs.append(c)
training_accs.append(acc)
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c), "accuracy=", "{:.9f}".format(acc))
print("Optimization Finished!")
# 测试MLP模型
test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Test Accuracy:", test_acc)
# 绘制训练次数和准确率的图像
plt.plot(range(num_epochs), training_accs)
plt.xlabel('Training Epochs')
plt.ylabel('Training Accuracy')
plt.show()
```
运行代码后,将输出每个epoch的损失和准确率,并在训练结束时输出测试准确率。此外,还将绘制训练次数和准确率的图像,以便更好地了解模型的训练情况。
阅读全文