写一个LSTM分类模型
时间: 2023-11-07 21:58:19 浏览: 78
好的,下面是一个简单的LSTM分类模型的代码示例:
```python
import torch
import torch.nn as nn
class LSTMClassifier(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(LSTMClassifier, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size)
self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=2, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
embedded = self.embedding(x)
output, _ = self.lstm(embedded)
output = output[:, -1, :]
logits = self.fc(output)
return logits
```
这个模型接受一个输入序列 x,通过嵌入层将其转换为嵌入向量,然后输入到LSTM层中。最后将LSTM层的输出通过全连接层进行分类预测。
阅读全文