用rnn模型预测正弦函数,且可视化损失函数,隐藏层
时间: 2024-03-10 08:49:04 浏览: 84
RNN简单应用—-预测正弦函数
以下是使用PyTorch实现RNN预测正弦函数并可视化损失函数和隐藏层的代码示例:
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# 定义RNN模型
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(input_size + hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
hidden = self.i2h(combined)
output = self.i2o(combined)
output = self.softmax(output)
return output, hidden
def initHidden(self):
return torch.zeros(1, self.hidden_size)
# 定义参数
input_size = 1
hidden_size = 32
output_size = 1
batch_size = 1
seq_len = 100
learning_rate = 0.01
epochs = 1000
# 初始化模型和损失函数
rnn = RNN(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
# 生成正弦函数数据
time_steps = np.linspace(0, 2*np.pi, seq_len+1)
data = np.sin(time_steps)
data.resize((seq_len+1, 1))
# 训练模型并可视化损失函数和隐藏层
losses = []
hiddens = []
for epoch in range(1, epochs+1):
hidden = rnn.initHidden()
rnn.zero_grad()
loss = 0
for i in range(seq_len):
input = torch.tensor(data[i], dtype=torch.float32).view(1, 1, -1)
target = torch.tensor(data[i+1], dtype=torch.float32).view(1, -1)
output, hidden = rnn(input, hidden)
loss += criterion(output, target)
hiddens.append(hidden.detach().numpy().flatten())
loss.backward()
for p in rnn.parameters():
p.data.add_(-learning_rate, p.grad.data)
losses.append(loss.item())
# 可视化损失函数和隐藏层
plt.figure()
plt.plot(losses)
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.show()
plt.figure()
plt.imshow(np.array(hiddens).T, cmap='viridis', origin='lower')
plt.colorbar()
plt.xlabel('Time Step')
plt.ylabel('Hidden Size')
plt.title('Hidden Layer')
plt.show()
```
该代码使用了一个简单的RNN模型,以正弦函数为例,训练模型并可视化损失函数和隐藏层。在代码中,我们首先定义了一个RNN模型和训练参数,生成正弦函数数据,然后训练模型并计算损失函数,并将损失函数和隐藏层可视化。在训练过程中,我们遍历每个时间步,计算输出和隐层状态,并将隐层状态添加到列表中,用于可视化隐藏层。最后,我们使用Matplotlib库可视化损失函数和隐藏层。
阅读全文