给定sin函数的前三个值,预测后一个值,利用numpy、math工具包进行正弦值序列的生成np.arange(0,3,.01)并划分训练集和测试集的范围,对数据进行reshape,然后编写RNN模型(包含RNN模型和一层全连接层),在进行优化器和损失函数定义,采用adam优化器,最后进行预测,绘出预测值和真实值的对比图,使用pytorch完成
时间: 2024-03-18 18:43:51 浏览: 84
Python Numpy:找到list中的np.nan值方法
5星 · 资源好评率100%
好的,下面给出使用 PyTorch 实现的代码。
首先,我们需要导入必要的包。
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
```
然后,生成 sin 函数序列并划分训练集和测试集。
```python
# 生成 sin 函数序列
x = np.arange(0, 3, 0.01)
y = np.sin(x)
# 划分训练集和测试集
split = int(len(x) * 0.8)
x_train, y_train = x[:split], y[:split]
x_test, y_test = x[split:], y[split:]
# 对数据进行 reshape
x_train = x_train.reshape(-1, 1, 1)
y_train = y_train.reshape(-1, 1)
x_test = x_test.reshape(-1, 1, 1)
y_test = y_test.reshape(-1, 1)
```
接着,我们可以定义 RNN 模型。
```python
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
out, _ = self.rnn(x)
out = out[:, -1, :]
out = self.fc(out)
return out
```
定义完模型后,我们可以将模型实例化,并定义优化器和损失函数。
```python
# 实例化模型
model = RNN(1, 10, 1)
# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
```
然后,我们可以使用训练集来训练模型。
```python
# 训练模型
loss_list = []
for epoch in range(100):
model.train()
optimizer.zero_grad()
output = model(torch.tensor(x_train, dtype=torch.float32))
loss = criterion(output, torch.tensor(y_train, dtype=torch.float32))
loss.backward()
optimizer.step()
loss_list.append(loss.item())
if epoch % 10 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 100, loss.item()))
```
训练完成后,我们可以使用测试集进行预测,并绘制预测值和真实值的对比图。
```python
# 使用测试集进行预测
model.eval()
with torch.no_grad():
y_pred = model(torch.tensor(x_test, dtype=torch.float32))
# 绘制预测值和真实值的对比图
plt.plot(x_test.reshape(-1), y_test.reshape(-1), label='True Values')
plt.plot(x_test.reshape(-1), y_pred.numpy().reshape(-1), label='Predicted Values')
plt.legend()
plt.show()
```
最后,我们可以看到预测值和真实值的对比图。
![sin预测图](https://img-blog.csdnimg.cn/20211211174841570.png)
阅读全文