rnn wiring
时间: 2024-12-30 09:33:43 浏览: 7
### RNN架构与连接解释
#### 传统RNN的内部结构
循环神经网络(Recurrent Neural Network, RNN)是一种用于处理序列数据的神经网络。不同于前馈神经网络,其中的信息仅沿单一方向传播,在RNN中,信息可以在多个时间步之间传递。这种特性使得RNN能够捕捉到输入数据的时间依赖关系。
在传统的RNN架构中,隐藏层的状态不仅取决于当前时刻的输入,还受到之前时刻状态的影响。具体来说,对于给定的一个时间步$t$,其隐藏状态$h_t$由两部分组成:一部分来自当前时刻的输入$x_t$;另一部分则来自于上一时刻的隐藏状态$h_{t-1}$[^1]。
#### 计算过程
假设有一个简单的RNN单元,则该单元内的计算可以表示如下:
$$ h_t = \sigma(W_h x_t + U_h h_{t-1} + b_h) $$
这里,
- $h_t$ 表示第$t$个时间步下的隐藏状态;
- $\sigma(\cdot)$ 是激活函数,通常采用双曲正切(tanh)或Sigmoid函数;
- $W_h$, $U_h$ 分别代表权重矩阵;
- $b_h$ 则是偏置向量[^2]。
为了更直观理解这一机制,下面给出一段Python代码来实现上述公式:
```python
import torch.nn as nn
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super(SimpleRNN, self).__init__()
self.rnn_cell = nn.RNNCell(input_size=input_size,
hidden_size=hidden_size)
def forward(self, inputs, initial_hidden_state=None):
batch_size, seq_len, _ = inputs.shape
if initial_hidden_state is None:
ht = torch.zeros(batch_size, self.hidden_size).to(inputs.device)
else:
ht = initial_hidden_state
outputs = []
for t in range(seq_len):
xt = inputs[:, t, :]
ht = self.rnn_cell(xt, ht)
outputs.append(ht.unsqueeze(0))
return torch.cat(outputs), ht
```
此代码定义了一个基于`nn.RNNCell`类构建的基础版单层RNN模型,并实现了前向传播逻辑。通过迭代遍历整个序列长度(`seq_len`),依次更新每个时间步上的隐藏状态($ht$),最终返回所有时间步的结果以及最后一个时间步的隐藏状态。
阅读全文