Write code for an RNN model that forecasts a time series
Time: 2024-01-22 07:18:57
Below is a simple RNN model that predicts the next value of a time series:
```python
import numpy as np
import tensorflow as tf

# Generate synthetic time series: two sine waves plus noise
def generate_time_series(batch_size, n_steps):
    freq1, freq2, offsets1, offsets2 = np.random.rand(4, batch_size, 1)
    time = np.linspace(0, 1, n_steps)
    series = 0.5 * np.sin((time - offsets1) * (freq1 * 10 + 10))  # wave 1
    series += 0.2 * np.sin((time - offsets2) * (freq2 * 20 + 20))  # wave 2
    series += 0.1 * (np.random.rand(batch_size, n_steps) - 0.5)    # noise
    return series[..., np.newaxis].astype(np.float32)  # (batch_size, n_steps, 1)

n_steps = 50
n_inputs = 1
n_neurons = 100
n_outputs = 1

# Define the model: a basic RNN layer whose last-step output feeds a dense layer
model = tf.keras.Sequential([
    tf.keras.Input(shape=(n_steps, n_inputs)),
    tf.keras.layers.SimpleRNN(n_neurons, activation="relu"),
    tf.keras.layers.Dense(n_outputs),
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss="mse")

# Train the model
n_iterations = 1000
batch_size = 50
for iteration in range(n_iterations):
    # Generate n_steps + 1 points: the first n_steps are the input,
    # the final point is the target the model must predict
    series = generate_time_series(batch_size, n_steps + 1)
    X_batch, y_batch = series[:, :n_steps], series[:, -1]
    mse = model.train_on_batch(X_batch, y_batch)
    if iteration % 100 == 0:
        print(iteration, "\tMSE:", mse)

# Use the trained model to predict the next value of a new series
X_new = generate_time_series(1, n_steps)
y_pred = model.predict(X_new, verbose=0)
print("Predicted next value:", y_pred[0, 0])
```
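The model uses a basic RNN cell followed by a fully connected (dense) layer that converts the output of the last time step into a single predicted value. During training, the Adam optimizer minimizes the mean squared error (MSE) between the prediction and the target. After training, the model is used to predict the next value of a new time series.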
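To make the "last time step → dense layer" idea concrete, here is a minimal NumPy sketch of the forward pass such a model computes, assuming a single-layer RNN with ReLU activation and a zero initial hidden state. The helper names `rnn_forecast` and `mse` and the weight names (`W_x`, `W_h`, `b_h`, `W_out`, `b_out`) are hypothetical, and the weights are random rather than trained, so the predicted values themselves mean nothing; only the recurrence and the shapes matter.

```python
import numpy as np

def rnn_forecast(x, W_x, W_h, b_h, W_out, b_out):
    """Hypothetical helper: forward pass of a ReLU RNN plus a dense output layer.

    x has shape (batch, n_steps, n_inputs); the result has shape (batch, n_outputs).
    """
    batch, n_steps, _ = x.shape
    h = np.zeros((batch, W_h.shape[0]))  # initial hidden state h_0 = 0 (assumed)
    for t in range(n_steps):
        # h_t = relu(x_t @ W_x + h_{t-1} @ W_h + b_h), same weights at every step
        h = np.maximum(0.0, x[:, t, :] @ W_x + h @ W_h + b_h)
    # Only the last hidden state h_T is mapped to the predicted value
    return h @ W_out + b_out

def mse(y_pred, y_true):
    """Mean squared error, the training loss described above."""
    return np.mean((y_pred - y_true) ** 2)

# Usage with random (untrained) weights and the sizes used in the answer
rng = np.random.default_rng(0)
n_steps, n_inputs, n_neurons, n_outputs = 50, 1, 100, 1
x = rng.standard_normal((4, n_steps, n_inputs))
W_x = 0.1 * rng.standard_normal((n_inputs, n_neurons))
W_h = 0.05 * rng.standard_normal((n_neurons, n_neurons))
b_h = np.zeros(n_neurons)
W_out = 0.1 * rng.standard_normal((n_neurons, n_outputs))
b_out = np.zeros(n_outputs)

y_pred = rnn_forecast(x, W_x, W_h, b_h, W_out, b_out)
print(y_pred.shape)  # (4, 1)
```

The loop makes the sequential nature of the RNN explicit: every time step reuses the same weights, and only the final hidden state reaches the output layer, so the single prediction summarizes the whole 50-step input window.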