nn.LSTM 内部结构是怎样的
时间: 2023-11-14 13:06:25 浏览: 64
nn.LSTM是一个长短时记忆网络,用于处理序列数据。它由一个输入门、一个遗忘门、一个输出门和一个新单元状态组成。
具体地,对于一个时间步$t$,LSTM的计算流程如下:
1. 输入门:根据输入$x_t$和上一个时间步的隐藏状态$h_{t-1}$计算一个值$i_t$,表示需要从输入中记忆多少信息。即$i_t=\sigma(W_{ii}x_t+W_{hi}h_{t-1}+b_i)$,其中$\sigma$是sigmoid函数。
2. 遗忘门:根据输入$x_t$和上一个时间步的隐藏状态$h_{t-1}$计算一个值$f_t$,表示需要从上一个时间步的单元状态中遗忘多少信息。即$f_t=\sigma(W_{if}x_t+W_{hf}h_{t-1}+b_f)$。
3. 更新单元状态:根据输入$x_t$、上一个时间步的隐藏状态$h_{t-1}$、输入门的值$i_t$和遗忘门的值$f_t$,计算一个新的单元状态$c_t$。即$c_t=f_t \odot c_{t-1} + i_t \odot \tanh(W_{ic}x_t+W_{hc}h_{t-1}+b_c)$,其中$\odot$表示逐元素相乘,$\tanh$表示双曲正切函数。
4. 输出门:根据输入$x_t$、上一个时间步的隐藏状态$h_{t-1}$和更新后的单元状态$c_t$,计算一个输出值$o_t$,表示需要输出多少信息。即$o_t=\sigma(W_{io}x_t+W_{ho}h_{t-1}+W_{co}c_t+b_o)$。
5. 隐藏状态:将更新后的单元状态$c_t$通过$\tanh$函数进行处理,并结合输出门的值$o_t$,得到隐藏状态$h_t$,即$h_t=o_t \odot \tanh(c_t)$。
这样,LSTM就完成了一个时间步的计算。在处理整个序列时,LSTM会不断重复上述计算流程,直到处理完所有时间步。
阅读全文