nn.LSTM的源码是什么
时间: 2024-01-07 21:14:57 浏览: 84
CNN+LSTM EGG项目,cnn-lstm,Python源码.zip
5星 · 资源好评率100%
nn.LSTM是PyTorch中预定义的循环神经网络模块,其源代码位于PyTorch GitHub仓库的 torch/nn/modules/rnn.py 文件中。以下是nn.LSTM实现的简化示意(具体代码随PyTorch版本不同而有所变化):
```python
class LSTM(RNNBase):
    """LSTM module built on ``RNNBase`` with the mode fixed to ``'LSTM'``.

    All constructor arguments are forwarded unchanged to ``RNNBase``; this
    class only adds validation of the ``(h, c)`` state pair and the forward
    pass itself.
    """

    def __init__(self, *args, **kwargs):
        super(LSTM, self).__init__('LSTM', *args, **kwargs)

    def check_forward_args(self, input, hidden, batch_sizes):
        """Validate ``input`` and both tensors of the ``hidden`` pair.

        Args:
            input: Input tensor (the flat ``data`` tensor when the caller
                passed a ``PackedSequence``).
            hidden: Pair ``(h, c)``. Each entry must have exactly the size
                returned by ``get_expected_hidden_size``.
            batch_sizes: ``batch_sizes`` of the packed input, or ``None`` for
                a regular tensor input.

        Raises:
            RuntimeError: If ``hidden[0]`` or ``hidden[1]`` has an unexpected
                size (checked in that order).
        """
        self.check_input(input, batch_sizes)
        expected = self.get_expected_hidden_size(input, batch_sizes)

        # The constructor always selects mode 'LSTM', so the state is always
        # an (h, c) pair; there is no single-tensor case to handle here.
        messages = (
            'Expected hidden[0] size {}, got {}',
            'Expected hidden[1] size {}, got {}',
        )
        for index, msg in enumerate(messages):
            actual = tuple(hidden[index].size())
            if actual != expected:
                raise RuntimeError(msg.format(expected, actual))

    def forward(self, input, hx=None):
        """Run the LSTM over ``input``.

        Args:
            input: Either a tensor or a ``PackedSequence``.
            hx: Optional ``(h_0, c_0)`` pair. When omitted, both states are
                zero tensors of size
                ``(num_layers * num_directions, batch, hidden_size)``.

        Returns:
            ``(output, (h_n, c_n))``. ``output`` is re-wrapped in a
            ``PackedSequence`` when ``input`` was one.

        Raises:
            RuntimeError: If a supplied state tensor has the wrong size.

        Note:
            Index permutations stored on a packed input (``sorted_indices`` /
            ``unsorted_indices``) are not applied here; pack inputs with the
            default ``enforce_sorted=True``.
        """
        # Imported locally so the block does not depend on a file-level import.
        from torch import _VF

        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes = input.data, input.batch_sizes
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = input.new_zeros(self.num_layers * num_directions,
                                    max_batch_size, self.hidden_size,
                                    requires_grad=False)
            # An LSTM carries two states: the hidden state and the cell state.
            hx = (zeros, zeros)

        self.check_forward_args(input, hx, batch_sizes)

        # The native kernel takes one flat list of parameters, whereas
        # ``all_weights`` groups them per layer and direction.
        flat_weights = [w for group in self.all_weights for w in group]
        if is_packed:
            # Packed data carries no batch dimension, hence no batch_first.
            result = _VF.lstm(
                input, batch_sizes, hx, flat_weights, self.bias,
                self.num_layers, self.dropout, self.training,
                self.bidirectional)
        else:
            result = _VF.lstm(
                input, hx, flat_weights, self.bias, self.num_layers,
                self.dropout, self.training, self.bidirectional,
                self.batch_first)

        output, hidden = result[0], result[1:]
        if is_packed:
            output = PackedSequence(output, batch_sizes)
        return output, hidden
```
以上代码实现了LSTM的前向传播过程。其中,针对PackedSequence(打包序列)输入的处理分支用于支持变长序列输入。
阅读全文