关于rnn模型预测的python代码
时间: 2023-10-04 20:13:46 浏览: 151
RNN_股票预测,rnn股票预测,Python
5星 · 资源好评率100%
以下是一个简单的RNN模型的Python代码示例,用于时间序列预测:
```python
import numpy as np
import tensorflow as tf
# 设置随机种子
tf.set_random_seed(101)
# 准备数据
X = np.linspace(0, 50, 501)
y = np.sin(X)
# 设置时间步长
time_steps = 30
# 创建训练集
X_train = []
y_train = []
for i in range(len(y) - time_steps):
X_train.append(y[i:i+time_steps])
y_train.append(y[i+time_steps])
# 转换为 numpy 数组
X_train = np.array(X_train)
y_train = np.array(y_train)
# 创建模型
inputs = tf.placeholder(tf.float32, [None, time_steps])
targets = tf.placeholder(tf.float32, [None, 1])
# 定义 RNN 层
num_neurons = 100
cell = tf.contrib.rnn.BasicRNNCell(num_units=num_neurons, activation=tf.nn.relu)
outputs, states = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
# 定义输出层
num_outputs = 1
outputs = tf.layers.dense(outputs[:, -1], num_outputs)
# 定义损失函数和优化器
learning_rate = 0.001
loss = tf.losses.mean_squared_error(targets, outputs)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train = optimizer.minimize(loss)
# 初始化变量
init = tf.global_variables_initializer()
# 定义批次大小和迭代次数
batch_size = 50
num_epochs = 1000
# 创建会话
with tf.Session() as sess:
sess.run(init)
for epoch in range(num_epochs):
# 随机抽取批次
rand_index = np.random.choice(len(X_train), size=batch_size)
X_batch = X_train[rand_index]
y_batch = y_train[rand_index].reshape(-1, 1)
# 训练模型
sess.run(train, feed_dict={inputs: X_batch, targets: y_batch})
# 每 100 次迭代输出一次结果
if epoch % 100 == 0:
mse = loss.eval(feed_dict={inputs: X_batch, targets: y_batch})
print("Epoch: {}, MSE: {}".format(epoch, mse))
# 使用训练好的模型进行预测
y_pred = sess.run(outputs, feed_dict={inputs: X_train})
# 绘制预测结果和实际结果的图像
import matplotlib.pyplot as plt
plt.plot(X, y, label="Actual")
plt.plot(X[:-time_steps], y_pred, label="Predicted")
plt.legend()
plt.show()
```
以上代码是一个简单的RNN模型,用于预测一个正弦函数。首先,我们通过 `np.linspace()` 函数生成了一个包含 501 个点的正弦函数作为数据集。然后,我们设置了时间步长,并用 `for` 循环创建了训练集。接下来,我们定义了 RNN 层和输出层,然后定义了损失函数和优化器。之后,我们使用 TensorFlow 的会话来训练模型,并输出了每个 epoch 的 MSE(均方误差)。最后,我们使用训练好的模型来预测整个数据集,并使用 Matplotlib 绘制了预测结果和实际结果的图像。
阅读全文