给出rnn实现文本分类的代码
时间: 2023-08-31 07:12:48 浏览: 174
以下是一个基于RNN的文本分类的示例代码,使用PyTorch实现:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# 定义模型
class RNNClassifier(nn.Module):
def __init__(self, vocab_size, embed_size, hidden_size, num_classes):
super().__init__()
self.embedding = nn.EmbeddingBag(vocab_size, embed_size, sparse=True)
self.rnn = nn.RNN(embed_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, text):
embedded = self.embedding(text)
output, hidden = self.rnn(embedded.unsqueeze(0))
return self.fc(hidden.squeeze(0))
# 预处理数据
tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
train_iter = AG_NEWS(split='train', vocab=vocab)
num_classes = len(train_iter.get_labels())
# 初始化模型和优化器
model = RNNClassifier(len(vocab), 64, 128, num_classes)
optimizer = optim.Adam(model.parameters())
# 训练模型
EPOCHS = 10
for epoch in range(EPOCHS):
for text, label in train_iter:
model.zero_grad()
output = model(text)
loss = nn.functional.cross_entropy(output.unsqueeze(0), label.unsqueeze(0))
loss.backward()
optimizer.step()
# 测试模型
test_iter = AG_NEWS(split='test', vocab=vocab)
correct = 0
total = 0
with torch.no_grad():
for text, label in test_iter:
output = model(text)
predicted = torch.argmax(output).item()
total += 1
if predicted == label.item():
correct += 1
print(f'Accuracy: {correct/total}')
```
在此示例代码中,我们首先定义了一个名为`RNNClassifier`的模型,该模型使用`EmbeddingBag`层将单词嵌入向量,然后使用`RNN`层处理这些向量,并使用线性层将输出映射到类别标签上。我们还使用了PyTorch内置的交叉熵损失函数来计算模型的损失。
我们使用`AG_NEWS`数据集进行训练和测试。该数据集包含120,000个新闻文本和4个类别标签(World、Sports、Business、Sci/Tech)。我们使用`get_tokenizer`函数和`build_vocab_from_iterator`函数来预处理数据,然后使用`AG_NEWS`迭代器加载数据。
在训练过程中,我们使用Adam优化器来最小化损失。在测试过程中,我们使用测试集评估模型的准确性。
阅读全文