IMDb bidirectional LSTM implementation with PaddlePaddle 2.3
Below is example code for a bidirectional LSTM sentiment-classification model on the IMDb dataset with PaddlePaddle 2.3. It uses the built-in `paddle.text.datasets.Imdb` dataset and pads the variable-length reviews to a fixed length so they can be batched:
```python
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
from paddle.io import DataLoader
from paddle.text.datasets import Imdb

# Hyperparameters
batch_size = 128
embedding_size = 128
hidden_size = 128
num_layers = 2
num_classes = 2
learning_rate = 0.001
num_epochs = 5
max_len = 200  # reviews are truncated/padded to this length

# Load the built-in IMDb dataset (downloaded on first use)
train_data = Imdb(mode='train')
test_data = Imdb(mode='test')

vocab_size = len(train_data.word_idx)
pad_id = vocab_size  # reserve one extra id beyond the vocabulary for padding

def collate_batch(batch):
    # Truncate/pad the variable-length reviews so they can be stacked into a batch
    docs, labels = [], []
    for doc, label in batch:
        doc = np.asarray(doc, dtype='int64')[:max_len]
        padded = np.full(max_len, pad_id, dtype='int64')
        padded[:len(doc)] = doc
        docs.append(padded)
        labels.append(int(label))
    return np.stack(docs), np.asarray(labels, dtype='int64')

# Data loaders
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                          collate_fn=collate_batch)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
                         collate_fn=collate_batch)

# Model definition
class BiLSTM(nn.Layer):
    def __init__(self, vocab_size, embedding_size, hidden_size, num_layers, num_classes):
        super(BiLSTM, self).__init__()
        # vocab_size + 1 leaves room for the padding id
        self.embedding = nn.Embedding(vocab_size + 1, embedding_size, padding_idx=pad_id)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers, direction='bidirectional')
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        x = self.embedding(x)        # [batch, seq_len, embedding_size]
        x, _ = self.lstm(x)          # [batch, seq_len, hidden_size * 2]
        x = self.fc(x[:, -1, :])     # classify from the last time step (simplification)
        return x

# Instantiate the model, loss function and optimizer
model = BiLSTM(vocab_size, embedding_size, hidden_size, num_layers, num_classes)
loss_fn = nn.CrossEntropyLoss()
optimizer = opt.Adam(learning_rate=learning_rate, parameters=model.parameters())

# Training loop
model.train()
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, len(train_loader), float(loss)))

# Evaluate on the test set
model.eval()
with paddle.no_grad():
    correct = 0
    total = 0
    for inputs, labels in test_loader:
        outputs = model(inputs)
        predicted = paddle.argmax(outputs, axis=1)
        total += labels.shape[0]
        correct += int((predicted.numpy() == labels.numpy()).sum())
    print('Test Accuracy: {:.2f}%'.format(correct / total * 100))
```
This bidirectional LSTM model uses PaddlePaddle's `nn.Embedding`, `nn.LSTM`, and `nn.Linear` layers for word embedding, the bidirectional LSTM computation, and the final fully connected classification layer. Because the IMDb reviews have variable length, the data loaders use a small `collate_fn` that truncates or pads each review to a fixed length before batching. Training uses the cross-entropy loss and the Adam optimizer; during evaluation, the `paddle.no_grad()` context manager disables gradient computation to speed up inference.
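As a follow-up, here is a minimal sketch of how the trained model could be used to score a single raw review. The `predict_sentiment` helper is hypothetical (not part of the code above); it assumes `word_idx` maps lowercase tokens to integer ids (the key type, `str` vs `bytes`, may differ between Paddle versions) and reuses the `pad_id` and `max_len` values from the training script:

```python
import numpy as np
import paddle

def predict_sentiment(model, word_idx, text, pad_id, max_len=200):
    # Sketch only: tokenization and <unk> handling are assumptions, adjust to your vocabulary.
    unk_id = word_idx.get('<unk>', 0)
    tokens = text.lower().split()
    ids = [word_idx.get(tok, unk_id) for tok in tokens][:max_len]
    # Pad to the same fixed length used during training
    padded = np.full(max_len, pad_id, dtype='int64')
    padded[:len(ids)] = ids
    inputs = paddle.to_tensor(padded[np.newaxis, :])  # add a batch dimension

    model.eval()
    with paddle.no_grad():
        logits = model(inputs)
        probs = paddle.nn.functional.softmax(logits, axis=-1)
        label = int(paddle.argmax(probs, axis=-1))
    return label, probs.numpy()[0]

# Example usage
label, probs = predict_sentiment(model, train_data.word_idx,
                                 'a surprisingly touching and well acted film', pad_id)
print('predicted class:', label, 'probabilities:', probs)
```

Note that this simple whitespace tokenization will not exactly match the preprocessing used to build the dataset's vocabulary, so it is only meant to illustrate the inference flow.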