请给我一段Pytorch的LSTM模型对自然语言处理的实际代码
时间: 2024-03-05 12:53:37 浏览: 135
好的,以下是一个简单的Pytorch LSTM模型的自然语言处理代码示例:
```python
import torch
import torch.nn as nn
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(LSTM, self).__init__()
self.hidden_size = hidden_size
self.lstm = nn.LSTM(input_size, hidden_size)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, input):
batch_size = input.size(1)
hidden = self.init_hidden(batch_size)
lstm_out, hidden = self.lstm(input, hidden)
output = self.fc(lstm_out[-1])
return output
def init_hidden(self, batch_size):
return (torch.zeros(1, batch_size, self.hidden_size),
torch.zeros(1, batch_size, self.hidden_size))
# Example usage
input_size = 100
hidden_size = 256
output_size = 10
model = LSTM(input_size, hidden_size, output_size)
input = torch.randn(5, 3, input_size) # sequence length, batch size, input size
output = model(input)
print(output.size()) # should be (3, 10) - batch size, output size
```
在这个例子中,我们定义了一个LSTM模型,它的输入大小为100,隐藏大小为256,输出大小为10。我们在输入上运行LSTM,并将最后一个输出传递给一个全连接层,以产生最终的输出。在这个例子中,我们使用了PyTorch中的nn.LSTM层,它自动处理序列长度和批量大小,因此我们可以一次将多个序列传递给模型。
阅读全文