请你学习pytorch相关应用教程,给出一个以pytorch实现的具有LSTM循环神经网络(LSTM-RNN) 的例子, 解释代码和运行结果。
时间: 2024-12-28 11:37:41 浏览: 5
在PyTorch库中,LSTM (长短期记忆) 网络是一种常用的循环神经网络,特别适合处理序列数据,如文本、音频或时间序列预测。下面是一个简单的例子,我们将构建一个LSTM用于文本生成:
```python
import torch
import torch.nn as nn
# NOTE(review): `torchtext.data.Field` / `BucketIterator` are the legacy torchtext
# API, removed in torchtext 0.12+ — this import fails on modern installs.
# `BucketIterator` is imported but never used in this script.
from torchtext.data import Field, BucketIterator
# Initialize hyperparameters
input_dim = 50 # assume the input feature dimension is 50
hidden_dim = 256
num_layers = 2
batch_size = 64
seq_length = 10 # sequence length
output_dim = 10 # number of output classes (e.g., vocabulary size)
# Define the vocabulary field and create the datasets
TEXT = Field(tokenize='spacy', init_token='<sos>', eos_token='<eos>', lower=True)
# NOTE(review): `Field.ngrams` is not a real torchtext method — this call raises
# AttributeError; verify the intended data-loading API (TabularDataset/build_vocab
# was the usual legacy pattern).
train_data, valid_data, test_data = TEXT.ngrams('your_text_dataset.txt', n=seq_length)
TEXT.build_vocab(train_data, min_freq=2)
# Build the model
class LSTMModel(nn.Module):
    """Many-to-one LSTM classifier.

    Runs a stacked LSTM over an input batch of shape
    (batch, seq_len, input_dim) and maps the hidden state of the
    last time step to `output_dim` logits.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.hidden_dim = hidden_dim
        # FIX: store num_layers on the instance — the original forward() read the
        # module-level global `num_layers`, so a model built with a different
        # layer count would crash or silently use the wrong state shape.
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return logits of shape (batch, output_dim) for input x of shape
        (batch, seq_len, input_dim)."""
        # FIX: allocate the initial states on x's own device — the original used
        # the module-level global `device`, which breaks if the model is used
        # before/without that global or on a different device.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device)
        out, _ = self.lstm(x, (h0, c0))
        # Classify from the hidden state of the final time step only.
        out = self.fc(out[:, -1, :])
        return out
# Select GPU when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMModel(input_dim, hidden_dim, output_dim, num_layers).to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Training loop
for epoch in range(10): # assume we train for 10 epochs
    # NOTE(review): `train_loader` is never defined anywhere in this script —
    # running it raises NameError here. Presumably a BucketIterator over
    # `train_data` was intended; verify.
    for batch in train_loader:
        # NOTE(review): assumes each batch exposes `.text` and `.label`
        # attributes — TODO confirm against the (missing) loader definition.
        inputs, targets = batch.text.to(device), batch.label.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch only, not an epoch average.
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")
# Use the model for prediction (greedy text generation)
# NOTE(review): `TEXT.ngrams` is not a real torchtext API — verify how the seed
# sequence is meant to be produced.
sample_sequence = TEXT.ngrams("示例文本", seq_length)[0][0].numpy().tolist()
# FIX: the original comprehension indexed with the undefined name `sample`
# instead of the loop variable; map each seed token `s` to its vocab index.
sequence = [TEXT.vocab.stoi[s] for s in sample_sequence]
with torch.no_grad():
    for i in range(100):
        # Feed the model a sliding window of the last seq_length tokens.
        # NOTE(review): this passes integer indices of shape (1, seq_length),
        # but the LSTM expects float features of shape (batch, seq, input_dim);
        # an embedding layer appears to be missing — confirm.
        x = torch.tensor([sequence[-seq_length:]], device=device)
        pred = model(x)
        # Greedy decoding: take the highest-scoring next token.
        next_word_idx = pred.argmax(1).item()
        sequence.append(next_word_idx)
        # FIX: the original truncated `sequence` to its last seq_length tokens
        # here, discarding the generated history so the final print showed only
        # the last window; the input window is already sliced above, so keep
        # the full sequence.
# FIX: the index-to-string table lives on the vocab object — `TEXT.vocab.itos`,
# not `TEXT.itos`.
print("生成的文本:", ''.join(TEXT.vocab.itos[word_idx] for word_idx in sequence))
```
阅读全文