lstm的简单例子python
时间: 2023-09-28 18:07:40 浏览: 71
使用lstm的python简单实例——mnist
下面是一个简单的LSTM模型的Python代码示例:
```python
import numpy as np
import tensorflow as tf
# 定义输入数据和标签
x_data = np.array([[[1], [2], [5], [6]], [[2], [3], [6], [7]], [[3], [4], [7], [8]]])
y_data = np.array([[7], [8], [9]])
# 定义LSTM模型
input_dim = 1
hidden_dim = 10
output_dim = 1
x = tf.placeholder(tf.float32, [None, None, input_dim])
y = tf.placeholder(tf.float32, [None, output_dim])
lstm_cell = tf.contrib.rnn.BasicLSTMCell(hidden_dim)
outputs, states = tf.nn.dynamic_rnn(lstm_cell, x, dtype=tf.float32)
# 将LSTM输出转换为输出维度
output = tf.layers.dense(outputs[:, -1, :], output_dim)
# 定义损失函数和优化器
loss = tf.reduce_mean(tf.square(output - y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)
# 开始训练模型
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(1000):
_, l = sess.run([optimizer, loss], feed_dict={x: x_data, y: y_data})
if i % 100 == 0:
print("Step", i, "Loss", l)
# 预测模型输出
test_x = np.array([[[4], [5], [8], [9]]])
prediction = sess.run(output, feed_dict={x: test_x})
print("Prediction", prediction)
```
这个例子使用LSTM模型来预测一个时间序列的下一个数字。输入数据是三个长度为四的序列,输出数据是一个长度为一的数字。在训练时,使用Adam优化器和均方误差损失函数来训练模型。最后,使用训练好的模型对一个新的长度为四的序列进行预测,并输出预测结果。
阅读全文