pytorch Transformer
时间: 2025-01-08 13:49:18 浏览: 4
### 使用PyTorch实现Transformer模型
#### 宏观视角下的Transformer模型
在宏观层面,Transformer可以被视作一个执行序列到序列转换的任务模型[^1]。例如,在自然语言处理领域中的机器翻译任务里,输入一段源语言文字,经过一系列复杂的计算过程之后,能够得到目标语言的文字输出。
#### Transformer模型的主要组成部分
为了构建这样一个强大的工具,通常会创建一个名为`TransformerModel`的类来封装整个网络结构,并让这个类继承自`nn.Module`以便利用PyTorch框架提供的功能[^2]。具体来说,完整的Transformer架构包含了以下几个核心模块:
- **位置编码器(Positional Encoding)**: 由于原始版本的Transformer并没有像循环神经网络那样具备内在的时间/顺序感,因此引入了位置编码以帮助模型理解词序的重要性[^3]。
- **特征编码器(Feature Encoder / Encoder Layer Stack)**: 这一部分负责接收并处理来自输入端的数据流,通过多层堆叠的方式逐步提取高层次语义信息。
- **特征解码器(Feature Decoder / Decoder Layer Stack)**: 解码器接受来自编码器的信息以及其他可能存在的上下文提示作为输入,最终生成期望的目标序列。
- **线性变换与Softmax函数组成的输出层**:此部分用于将解码后的高维向量映射回词汇表空间,从而完成具体的预测工作。
下面给出了一段简化版的代码片段展示如何基于上述描述快速搭建起一个基本可用的Transformer实例:
```python
import torch.nn as nn
from torch import Tensor, LongTensor
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: Tensor, shape [seq_len, batch_size, embedding_dim]
"""
x = x + self.pe[:x.size(0)]
return self.dropout(x)
class TransformerModel(nn.Module):
def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
nlayers: int, dropout: float = 0.5):
super().__init__()
from torch.nn import TransformerEncoder, TransformerEncoderLayer
self.model_type = 'Transformer'
self.pos_encoder = PositionalEncoding(d_model, dropout)
encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.encoder = nn.Embedding(ntoken, d_model)
self.d_model = d_model
self.decoder = nn.Linear(d_model, ntoken)
self.init_weights()
def init_weights(self) -> None:
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
def forward(self, src: LongTensor, src_mask: Tensor) -> Tensor:
"""
Args:
src: Tensor, shape [seq_len, batch_size]
src_mask: Tensor, shape [seq_len, seq_len]
Returns:
output Tensor of shape [seq_len, batch_size, ntoken]
"""
src = self.encoder(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, src_mask)
output = self.decoder(output)
return output
```
这段代码展示了怎样定义一个简单的Transformer模型及其辅助的位置编码组件。需要注意的是这只是一个基础模板,实际应用时还需要考虑更多细节优化以及针对特定任务调整参数配置等问题。
阅读全文