帮我用写一个rnn预测短时间序列的pyhton代码,数据集样本数是64个
时间: 2024-05-08 22:17:24 浏览: 20
下面是一个简单的RNN预测短时间序列的Python代码,数据集样本数为64个:
```python
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN
# 生成数据集
seq_length = 10
data = np.sin(np.linspace(0, 2*np.pi, 64))
X = []
y = []
for i in range(len(data)-seq_length):
X.append(data[i:i+seq_length])
y.append(data[i+seq_length])
X = np.array(X)
y = np.array(y)
# 构建模型
model = Sequential()
model.add(SimpleRNN(16, input_shape=(seq_length, 1)))
model.add(Dense(1, activation='linear'))
model.compile(loss='mse', optimizer='adam')
# 训练模型
model.fit(X.reshape(-1, seq_length, 1), y, batch_size=16, epochs=100)
# 预测结果
y_pred = model.predict(X.reshape(-1, seq_length, 1))
# 可视化结果
plt.plot(data, label='True')
plt.plot(range(seq_length, len(y_pred)+seq_length), y_pred, label='Prediction')
plt.legend()
plt.show()
```
在这个代码中,我们生成了一个包含64个样本的数据集。每个样本是长度为10的序列,其对应的目标值是序列中的下一个元素。我们使用了一个简单的RNN模型来预测目标值,并使用均方误差作为损失函数进行训练。最后,我们可视化了真实序列和预测序列的结果。