Transformer代码
时间: 2023-09-03 20:09:23 浏览: 76
以下是使用PyTorch实现Transformer模型的代码:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a tensor, then apply dropout.

    A table of ``max_len`` position vectors is computed once in the
    constructor and stored as the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``. Even feature columns hold sines and odd
    columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the feature dimension. Both even and odd sizes are
            supported.
        dropout: Dropout probability applied to the sum in ``forward``.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd column than even column,
        # so the cosine half must be trimmed to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Singleton middle axis so the table broadcasts over dimension 1 of
        # the input in forward().
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``.

        The table is sliced by ``x.size(0)``, so dimension 0 of ``x`` is the
        position axis and the last dimension must equal ``d_model``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
class Transformer(nn.Module):
    """Token embedding + positional encoding + a stack of blocks + linear head.

    Args:
        input_dim: Vocabulary size passed to ``nn.Embedding``.
        hidden_dim: Embedding / model width.
        output_dim: Output size of the final linear layer.
        n_layers: Number of ``TransformerBlock`` layers.
        n_heads: Attention heads per block.
        pf_dim: Inner width of the position-wise feed-forward sub-layer.
        dropout: Dropout probability used throughout.
        device: Device the padding mask is moved to; also forwarded to the
            blocks.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.tok_embedding = nn.Embedding(input_dim, hidden_dim)
        self.pos_embedding = PositionalEncoding(hidden_dim, dropout)
        self.layers = nn.ModuleList(
            [TransformerBlock(hidden_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]
        )
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.device = device

    def forward(self, src):
        """Encode a batch of token ids.

        Args:
            src: Integer tensor of token ids, ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, output_dim)``.
        """
        # Build the mask from the batch-first ids so it is (batch, 1, 1, seq)
        # and broadcasts over heads and query positions in the attention.
        src_mask = self.make_src_mask(src)

        # The positional encoding indexes positions along dimension 0, so the
        # sequence-first layout is used only for the embedding step.
        embedded = self.tok_embedding(src.transpose(0, 1))
        embedded = self.dropout(self.pos_embedding(embedded))

        # The attention layers take the batch size from dimension 0; feeding
        # them sequence-first data would make them attend across the batch.
        hidden = embedded.transpose(0, 1)
        for layer in self.layers:
            hidden = layer(hidden, src_mask)

        return self.fc(hidden)

    def make_src_mask(self, src):
        """Return a boolean mask that is False wherever ``src`` equals 0.

        Two singleton axes are inserted after dimension 0, so a
        ``(batch, seq_len)`` input yields ``(batch, 1, 1, seq_len)``.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        src_mask = src_mask.to(self.device)
        return src_mask
class TransformerBlock(nn.Module):
    """One post-norm encoder block: self-attention, then feed-forward.

    Each sub-layer's output goes through dropout, is added to its input, and
    the sum is layer-normalised.

    Args:
        hidden_dim: Model width, used by both layer norms and sub-layers.
        n_heads: Number of attention heads.
        pf_dim: Inner width of the feed-forward sub-layer.
        dropout: Dropout probability.
        device: Forwarded to ``MultiHeadAttention``.
    """

    def __init__(self, hidden_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hidden_dim)
        self.ff_layer_norm = nn.LayerNorm(hidden_dim)
        self.self_attention = MultiHeadAttention(hidden_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforward(hidden_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Apply the attention and feed-forward sub-layers to ``src``.

        ``src`` serves as query, key and value; the attention weights
        returned by the attention module are discarded.
        """
        # Sub-layer 1: self-attention with residual connection and norm.
        attended = self.self_attention(src, src, src, src_mask)[0]
        normed = self.self_attn_layer_norm(src + self.dropout(attended))

        # Sub-layer 2: position-wise feed-forward with residual and norm.
        transformed = self.positionwise_feedforward(normed)
        return self.ff_layer_norm(normed + self.dropout(transformed))
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split over ``n_heads`` heads.

    Args:
        hidden_dim: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        device: Device on which the ``scale`` buffer is initially created.
    """

    def __init__(self, hidden_dim, n_heads, dropout, device):
        super().__init__()
        assert hidden_dim % n_heads == 0, "hidden_dim must be divisible by n_heads"
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads
        self.fc_q = nn.Linear(hidden_dim, hidden_dim)
        self.fc_k = nn.Linear(hidden_dim, hidden_dim)
        self.fc_v = nn.Linear(hidden_dim, hidden_dim)
        self.fc_o = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # A buffer follows the module through .to()/.cuda()/.half(), unlike a
        # plain tensor attribute. Non-persistent keeps state_dict keys
        # unchanged for existing checkpoints.
        self.register_buffer(
            'scale',
            torch.sqrt(torch.tensor([float(self.head_dim)], device=device)),
            persistent=False,
        )

    def forward(self, query, key, value, mask=None):
        """Compute multi-head attention.

        The batch size is read from ``query.shape[0]``, so inputs are
        batch-first with ``hidden_dim`` as the last dimension.

        Args:
            query, key, value: Input tensors projected by ``fc_q``, ``fc_k``
                and ``fc_v`` respectively.
            mask: Optional tensor broadcastable to the score tensor
                ``(batch, n_heads, query_len, key_len)``; positions where it
                equals 0 receive no attention.

        Returns:
            Tuple ``(x, attention)``: the output projected by ``fc_o`` with
            shape ``(batch, query_len, hidden_dim)``, and the attention
            weights after dropout.
        """
        batch_size = query.shape[0]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # (batch, len, hidden) -> (batch, heads, len, head_dim)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            # Use the dtype's own minimum: a literal such as -1e10 is not
            # representable in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention = self.dropout(F.softmax(scores, dim=-1))

        x = torch.matmul(attention, V)
        # (batch, heads, len, head_dim) -> (batch, len, hidden)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.head_dim)
        x = self.fc_o(x)
        return x, attention
class PositionwiseFeedforward(nn.Module):
    """Two-layer feed-forward network applied over the last dimension.

    Computes ``dropout(fc_2(dropout(relu(fc_1(x)))))``.

    Args:
        hidden_dim: Input and output width.
        pf_dim: Width of the inner layer.
        dropout: Dropout probability, applied after the activation and again
            after the second linear layer.
    """

    def __init__(self, hidden_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hidden_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``pf_dim``, apply ReLU, and project back to ``hidden_dim``."""
        expanded = self.dropout(F.relu(self.fc_1(x)))
        return self.dropout(self.fc_2(expanded))
```
这段代码实现的是一个仅包含编码器(Encoder)的Transformer模型,由词嵌入、位置编码、Multi-Head Attention、Positionwise Feedforward等模块组成,不包含解码器部分。用户可以根据自己的需求修改其中的参数和结构。
阅读全文