能否生成transformer的torch代码
时间: 2023-08-13 19:53:32 浏览: 85
当然可以!以下是一个简单的示例,用于生成Transformer模型的PyTorch代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, src, src_mask=None):
src2 = self.self_attn(src, src, src, attn_mask=src_mask)[0]
src = src + self.dropout(src2)
src = self.norm1(src)
src2 = self.linear2(F.relu(self.linear1(src)))
src = src + self.dropout(src2)
src = self.norm2(src)
return src
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers):
super(TransformerEncoder, self).__init__()
self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])
def forward(self, src, src_mask=None):
for layer in self.layers:
src = layer(src, src_mask)
return src
class Transformer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward, num_layers, num_classes):
super(Transformer, self).__init__()
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
self.fc = nn.Linear(d_model, num_classes)
def forward(self, src, src_mask=None):
src = self.encoder(src, src_mask)
src = src.mean(dim=1) # Global average pooling
logits = self.fc(src)
return logits
# 示例用法
d_model = 512
nhead = 8
dim_feedforward = 2048
num_layers = 6
num_classes = 10
transformer = Transformer(d_model, nhead, dim_feedforward, num_layers, num_classes)
src = torch.randn(16, 10, d_model) # 输入形状为 (batch_size, seq_len, d_model)
output = transformer(src)
print(output.shape) # 输出形状为 (batch_size, num_classes)
```
请注意,这只是一个简单的示例,实际上Transformer模型的实现可能更复杂,具体取决于您的任务和需求。你可以根据自己的需要进行修改和扩展。
阅读全文