LSTM分类代码pytorch
时间: 2023-06-18 13:05:39 浏览: 116
以下是一个简单的LSTM分类代码,使用PyTorch:
```python
import torch
import torch.nn as nn
class LSTMClassifier(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(LSTMClassifier, self).__init__()
self.hidden_size = hidden_size
self.lstm = nn.LSTM(input_size, hidden_size)
self.linear = nn.Linear(hidden_size, output_size)
def forward(self, input):
lstm_out, _ = self.lstm(input.view(len(input), 1, -1))
linear_out = self.linear(lstm_out.view(len(input), -1))
return linear_out[-1]
# 定义模型参数
input_size = 10 # 输入特征维度
hidden_size = 20 # LSTM隐藏层维度
output_size = 2 # 输出类别数
# 初始化模型
lstm_model = LSTMClassifier(input_size, hidden_size, output_size)
# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=0.01)
# 训练模型
for epoch in range(100):
for input_data, label in training_data:
lstm_model.zero_grad()
lstm_output = lstm_model(input_data)
loss = loss_fn(lstm_output, label)
loss.backward()
optimizer.step()
# 评估模型
lstm_model.eval()
correct = 0
total = 0
with torch.no_grad():
for input_data, label in test_data:
lstm_output = lstm_model(input_data)
_, predicted = torch.max(lstm_output.data, 1)
total += label.size(0)
correct += (predicted == label).sum().item()
accuracy = 100 * correct / total
print('Accuracy of the LSTM classifier on the test set: {:.2f}%'.format(accuracy))
```
在这个示例中,我们使用一个LSTM模型对输入数据进行分类,并通过交叉熵损失函数进行训练。我们使用Adam优化器进行梯度下降,并在训练期间迭代地更新模型参数。在训练完成后,我们评估模型在测试集上的准确度,并输出结果。
阅读全文