lstm模型pytorch
时间: 2024-09-23 09:00:16 浏览: 39
基于Pytorch框架的CNN-LSTM模型在CWRU轴承故障诊断的应用
5星 · 资源好评率100%
LSTM(长短期记忆网络)是一种递归神经网络,特别适合处理序列数据,如文本、语音和时间序列预测。在PyTorch中,你可以使用`nn.LSTM`模块来构建LSTM模型。以下是创建一个基本LSTM模型的基本步骤:
1. 导入所需的库:
```python
import torch
import torch.nn as nn
```
2. 定义LSTM层结构:
```python
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, batch_first=True):
super(LSTMModel, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first)
def forward(self, inputs):
# inputs: (batch_size, seq_len, input_size)
out, _ = self.lstm(inputs) # _ is for the cell state, not returned by default
return out[:, -1, :] # 返回最后一个时间步的隐藏状态
# 示例:假设输入每个样本有10个特征,每样本有20个时间步,LSTM隐藏层大小为64,两层
model = LSTMModel(10, 64, 2)
```
3. 初始化模型参数并训练:
```python
input = torch.randn(10, 20, 10) # 假设10个样本,每个样本20个时间步,10个特征
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss() # 如果你的任务是分类
for epoch in range(num_epochs):
output = model(input)
loss = loss_fn(output, target) # 把实际标签填入target
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
阅读全文