用rnn模型预测正弦函数,并可视化损失函数
时间: 2024-03-10 14:47:29 浏览: 24
以下是一个使用RNN模型预测正弦函数,并可视化损失函数的示例:
```python
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, SimpleRNN
from tensorflow.keras.optimizers import Adam
# 创建正弦函数的输入和输出序列
x = np.linspace(start=0, stop=50, num=1000)
y = np.sin(x)
# 创建序列窗口
seq_len = 50
data = []
for i in range(len(y) - seq_len):
data.append(y[i:i+seq_len])
data = np.array(data)
# 将数据集分为训练集和测试集
split = 0.8
idx = int(split * len(data))
X_train, y_train = data[:idx], data[1:idx+1]
X_test, y_test = data[idx:-1], data[idx+1:]
# 创建RNN模型
model = Sequential([
SimpleRNN(32, input_shape=(seq_len, 1)),
Dense(1)
])
model.compile(optimizer=Adam(lr=0.01), loss='mse')
# 训练模型
history = model.fit(X_train.reshape(-1, seq_len, 1), y_train,
validation_data=(X_test.reshape(-1, seq_len, 1), y_test),
epochs=100, batch_size=16)
# 可视化损失函数
plt.plot(history.history['loss'], label='Training loss')
plt.plot(history.history['val_loss'], label='Validation loss')
plt.legend()
plt.show()
# 使用训练好的模型进行预测
y_pred = model.predict(X_test.reshape(-1, seq_len, 1)).squeeze()
# 可视化预测结果
plt.plot(y_test, label='True values')
plt.plot(y_pred, label='Predicted values')
plt.legend()
plt.show()
```
在上面的代码中,我们首先生成了一个正弦函数的序列,并将其作为输入和输出序列。然后,我们创建了一个序列窗口,将输入序列划分为多个序列,并将其作为RNN的输入。我们使用一个简单的RNN模型,包含一个RNN层和一个全连接层。最后,我们使用训练好的模型预测测试集的输出序列,并可视化预测结果和损失函数变化曲线。
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![py](https://img-home.csdnimg.cn/images/20210720083646.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)