Building a seq2seq model with attention in PyTorch
In PyTorch, a sequence-to-sequence (Seq2Seq) model with an attention mechanism is commonly used for natural language processing tasks such as machine translation or text summarization. Here is a basic outline of the steps:
1. **Import libraries**:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
```
2. **Define the Encoder**:
Use an LSTM, GRU, or Transformer encoder. It consumes the input sequence and produces one hidden state per source position; these per-position outputs are what the decoder's attention mechanism will later attend over. (A Transformer encoder would additionally relate different parts of the input internally via self-attention or multi-head attention.)
```python
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, n_layers, dropout):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, enc_hid_dim, n_layers,
                           dropout=dropout, batch_first=True)

    def forward(self, src, src_lengths):
        # src: [batch, src_len], src_lengths: [batch]
        embedded = self.embedding(src)
        packed_embedded = pack_padded_sequence(
            embedded, src_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, (hidden, cell) = self.rnn(packed_embedded)
        # outputs: [batch, src_len, enc_hid_dim] -- one vector per source position
        outputs, _ = pad_packed_sequence(packed_output, batch_first=True)
        return outputs, hidden, cell
```
3. **Define the Decoder**:
The Decoder is usually an RNN as well, but at every step it can access the encoder's states: the attention layer turns them into a context vector that is fed into the decoder together with the previous target token.
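The Decoder below is constructed with an `attention` module that the post never defines. A minimal additive (Bahdanau-style) attention sketch, assuming the top layer of the decoder's hidden state is used as the query over the encoder outputs, could look like this:
```python
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(enc_hid_dim + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        # hidden: [batch, dec_hid_dim] -- top-layer decoder state (the query)
        # encoder_outputs: [batch, src_len, enc_hid_dim]
        src_len = encoder_outputs.size(1)
        query = hidden.unsqueeze(1).repeat(1, src_len, 1)   # [batch, src_len, dec_hid_dim]
        energy = torch.tanh(self.attn(torch.cat((query, encoder_outputs), dim=2)))
        scores = self.v(energy).squeeze(2)                   # [batch, src_len]
        attn_weights = F.softmax(scores, dim=1)              # attention distribution over source positions
        # context: weighted sum of encoder outputs, [batch, enc_hid_dim]
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attn_weights
```
A dot-product (Luong-style) score would work just as well; the only requirement here is that the module returns one context vector per batch element.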
```python
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, n_layers, dropout, attention):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        # the RNN input is the target embedding concatenated with the attention context
        self.rnn = nn.LSTM(emb_dim + enc_hid_dim, dec_hid_dim, n_layers,
                           dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(dec_hid_dim, output_dim)
        self.attention = attention

    def forward(self, input, hidden, cell, encoder_outputs):
        # input: [batch] -- the previous target token
        embedded = self.embedding(input).unsqueeze(1)        # [batch, 1, emb_dim]
        # attend over the encoder outputs, using the top-layer decoder state as query
        context, attn_weights = self.attention(hidden[-1], encoder_outputs)
        rnn_input = torch.cat((embedded, context.unsqueeze(1)), dim=2)
        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        prediction = self.fc_out(output.squeeze(1))          # [batch, output_dim]
        return prediction, hidden, cell, attn_weights
```
4. **Full model**:
Combine the Encoder and Decoder into a `Seq2Seq` wrapper, then set up the optimizer and loss; a sketch of the wrapper itself follows the snippet below.
```python
model = Seq2Seq(encoder, decoder)
optimizer = torch.optim.Adam(model.parameters())
# PAD_IDX is the index of the padding token in the target vocabulary
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
```
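The `Seq2Seq` class instantiated above is never defined in the original post. A minimal sketch, assuming the Encoder/Attention/Decoder interfaces shown earlier and that the encoder and decoder share the same hidden size and number of layers (so the encoder's final states can initialize the decoder directly), might look like:
```python
import random

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, src_lengths, trg, teacher_forcing_ratio=0.5):
        # src: [batch, src_len], trg: [batch, trg_len]
        batch_size, trg_len = trg.shape
        outputs = torch.zeros(batch_size, trg_len, self.decoder.output_dim, device=src.device)
        encoder_outputs, hidden, cell = self.encoder(src, src_lengths)
        input = trg[:, 0]  # first decoder input is the <sos> token
        for t in range(1, trg_len):
            prediction, hidden, cell, _ = self.decoder(input, hidden, cell, encoder_outputs)
            outputs[:, t] = prediction
            # teacher forcing: sometimes feed the ground-truth token, sometimes the model's own guess
            teacher_force = random.random() < teacher_forcing_ratio
            input = trg[:, t] if teacher_force else prediction.argmax(1)
        return outputs
```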
5. **Training and prediction**:
Iterate over the dataset and train with teacher forcing (feeding the ground-truth token as the decoder input for the next time step); at inference time, decode step by step and let the attention mechanism decide which encoder states to focus on for each generated token.
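For illustration only, a bare-bones training loop under the sketches above (assuming a hypothetical `train_loader` that yields `src`, `src_lengths`, `trg` batches, and a placeholder `NUM_EPOCHS`) could be:
```python
model.train()
for epoch in range(NUM_EPOCHS):
    for src, src_lengths, trg in train_loader:
        optimizer.zero_grad()
        output = model(src, src_lengths, trg)                 # [batch, trg_len, output_dim]
        # drop the first position (<sos>) and flatten for the cross-entropy loss
        loss = criterion(output[:, 1:].reshape(-1, output.size(-1)),
                         trg[:, 1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # guard against exploding gradients
        optimizer.step()
```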