RNN代码
时间: 2023-07-01 20:19:09 浏览: 77
以下是一个简单的基于Python和Keras的RNN代码示例:
```python
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense
import numpy as np
# 定义训练数据
X_train = np.array([[[1], [2], [3], [4], [5]]])
y_train = np.array([[6]])
# 定义 RNN 模型
model = Sequential()
model.add(SimpleRNN(units=32, input_shape=(5, 1)))
model.add(Dense(units=1))
# 编译模型
model.compile(optimizer='rmsprop', loss='mse')
# 训练模型
model.fit(X_train, y_train, epochs=10, batch_size=1)
# 预测
X_test = np.array([[[6], [7], [8], [9], [10]]])
y_pred = model.predict(X_test)
print(y_pred)
```
在上面的例子中,我们定义了一个非常简单的序列数据,包含了一个输入序列 `[1, 2, 3, 4, 5]` 和一个输出标签 `6`。我们使用 Keras 搭建了一个仅包含一个 SimpleRNN 层和一个 Dense 层的模型,并使用均方误差损失函数和 RMSprop 优化器进行训练。在训练完成后,我们用这个模型对输入序列 `[6, 7, 8, 9, 10]` 进行预测,并打印出了预测结果。
阅读全文