lstm文本分类pytorch
时间: 2023-10-01 22:12:06 浏览: 168
pytorch实现用CNN和LSTM对文本进行分类方式
5星 · 资源好评率100%
您好!要实现LSTM文本分类,您可以使用PyTorch库。下面是一个简单的示例代码:
首先,导入所需的库和模块:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
```
接下来,定义一个LSTM模型类:
```python
class LSTMModel(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers):
super(LSTMModel, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
embedded = self.embedding(x)
output, (hidden, cell) = self.lstm(embedded)
hidden = hidden[-1]
return self.fc(hidden)
```
然后,加载数据集并进行预处理:
```python
train_iter = AG_NEWS(split='train')
tokenizer = get_tokenizer('basic_english')
train_iter = [tokenizer(item[1]) for item in train_iter]
vocab = build_vocab_from_iterator(train_iter)
vocab.set_default_index(vocab["<unk>"])
vocab_size = len(vocab)
train_dataset = [[vocab[token] for token in tokens] for tokens in train_iter]
```
接下来,创建一个自定义的collate函数用于将数据打包成batch:
```python
def collate(batch):
labels = torch.tensor([entry[0] for entry in batch])
text = [entry[1:] for entry in batch]
text = nn.utils.rnn.pad_sequence(text, padding_value=vocab["<pad>"], batch_first=True)
return labels, text
```
然后,设置模型的超参数并创建数据加载器:
```python
embedding_dim = 100
hidden_dim = 256
output_dim = 4
num_layers = 2
batch_size = 64
epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
```
接下来,实例化模型并定义损失函数和优化器:
```python
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, output_dim, num_layers)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
```
然后,进行训练和评估:
```python
for epoch in range(epochs):
model.train()
for labels, text in train_dataloader:
labels = labels.to(device)
text = text.to(device)
optimizer.zero_grad()
outputs = model(text)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
model.eval()
# 在验证集上评估模型性能
```
这是一个基本的LSTM文本分类的示例,您可以根据您的数据集和任务需求进行适当的修改和调整。希望对您有帮助!如果有任何问题,请随时提问。
阅读全文