Wrapping a Transformer in PyTorch
### How to Implement or Wrap a Transformer Model with PyTorch
#### Creating the Transformer Class
To create a custom Transformer model based on `nn.Module`, subclass `torch.nn.Module` and override its constructor and forward method: initialize the individual components in the constructor, and define the data flow in `forward`.
```python
import torch
from torch import nn, Tensor
import math
class TransformerModel(nn.Module):
    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        # Sinusoidal positional encoding, defined in the next section
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        # A stack of nlayers identical self-attention encoder layers
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        # Token embedding ("encoder") and projection back to the vocabulary ("decoder")
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``
            src_mask: Tensor, shape ``[seq_len, seq_len]``

        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
        """
        # Scale embeddings by sqrt(d_model), as in the original paper
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output
```
This snippet shows how to build a simple Transformer model[^1]. Note the positional encoder (`PositionalEncoding`) that injects position information into the input: the attention mechanism itself is permutation-invariant, so without it the model would be blind to word order.
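As a quick smoke test, the model can be instantiated and run on random token ids. This is a minimal sketch with assumed, illustrative hyperparameters (not taken from the snippet above), and it relies on the `PositionalEncoding` class defined in the next section; the causal mask uses PyTorch's built-in `nn.Transformer.generate_square_subsequent_mask`:
```python
# Illustrative hyperparameters (assumed for this sketch)
ntoken, d_model, nhead, d_hid, nlayers = 1000, 64, 4, 128, 2
model = TransformerModel(ntoken, d_model, nhead, d_hid, nlayers, dropout=0.2)

seq_len, batch_size = 12, 8
src = torch.randint(0, ntoken, (seq_len, batch_size))               # token ids, [seq_len, batch_size]
src_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)  # causal mask, [seq_len, seq_len]

logits = model(src, src_mask)
print(logits.shape)  # torch.Size([12, 8, 1000]) -> [seq_len, batch_size, ntoken]
```
Note that `nn.TransformerEncoderLayer` defaults to `batch_first=False`, which is why the inputs here are sequence-first.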
#### Defining a Helper: Positional Encoding
Because Transformers do not rely on RNNs or other inherently sequential units, relative or absolute position information must be injected into each token explicitly.
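Concretely, the sinusoidal scheme from "Attention Is All You Need" assigns position $pos$ and dimension pair $i$ the values

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$

The `div_term` in the code below is the factor $10000^{-2i/d_{\text{model}}}$, computed in log space for numerical stability: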
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Positions 0..max_len-1 as a column vector
        position = torch.arange(max_len).unsqueeze(1)
        # Frequencies 10000^(-2i/d_model), computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)  # even dimensions: sine
        pe[:, 0, 1::2] = torch.cos(position * div_term)  # odd dimensions: cosine
        # A buffer is saved with the module's state but is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        # Add the encoding for the first seq_len positions, broadcasting over the batch
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
```
This implements the sinusoidal positional encoding scheme[^4], which is widely used across NLP tasks.
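As a minimal sanity check (dimensions assumed for illustration), applying the module to a zero tensor exposes the raw encoding:
```python
pos_enc = PositionalEncoding(d_model=16, dropout=0.0)
x = torch.zeros(10, 2, 16)  # [seq_len, batch_size, embedding_dim]
out = pos_enc(x)
print(out.shape)                             # torch.Size([10, 2, 16])
print(torch.allclose(out[:, 0], out[:, 1]))  # True: every batch element gets the same encoding
```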