Implementing a Transformer Neural Network in Python
The Transformer is a neural network architecture built on the self-attention mechanism. It has been highly successful in natural language processing, particularly in machine translation. A Transformer consists of a stack of encoder layers and a stack of decoder layers, each containing a self-attention sublayer and a position-wise feed-forward sublayer. Below is a simplified Python example of a Transformer building block, implemented with the PyTorch library:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces of size head_dim
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # einsum computes the query-key dot products for every head,
        # broadcast over the batch dimension:
        #   queries: (N, query_len, heads, head_dim)
        #   keys:    (N, key_len, heads, head_dim)
        #   energy:  (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Softmax over the key dimension; this simplified variant scales by
        # sqrt(embed_size) rather than the sqrt(head_dim) used in the paper
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # Weighted sum of the values, then concatenate the heads back together
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        # Residual (skip) connection, then layer normalization, then dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

# Example: embed_size = 256, heads = 8, dropout = 0.1, forward_expansion = 4
model = TransformerBlock(embed_size=256, heads=8, dropout=0.1, forward_expansion=4)

# Example input data
values = torch.rand((64, 50, 256))  # [batch_size, seq_length, embed_size]
keys = torch.rand((64, 50, 256))    # [batch_size, seq_length, embed_size]
query = torch.rand((64, 50, 256))   # [batch_size, seq_length, embed_size]
mask = None  # could be an actual mask tensor, e.g. to ignore padding positions

out = model(values, keys, query, mask)
print(out.shape)  # the output shape matches the shape of query: torch.Size([64, 50, 256])
```
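The example above simply passes mask=None. As an illustration of how a padding mask could be supplied to this attention implementation, here is a minimal sketch: the function name make_padding_mask, the pad_idx value, and the toy token IDs are assumptions for demonstration, not part of the original snippet. The mask is given shape (N, 1, 1, src_len) so it broadcasts against the energy tensor of shape (N, heads, query_len, key_len).

```python
# Sketch of a padding mask compatible with the SelfAttention class above
# (assumes integer token IDs and a hypothetical pad_idx of 0)
def make_padding_mask(src_tokens, pad_idx=0):
    # src_tokens: (N, src_len) integer tensor
    mask = (src_tokens != pad_idx).unsqueeze(1).unsqueeze(2)
    return mask  # shape (N, 1, 1, src_len); padding positions are 0 and get masked out

src_tokens = torch.tensor([[5, 7, 9, 0, 0], [3, 4, 0, 0, 0]])
mask = make_padding_mask(src_tokens, pad_idx=0)
print(mask.shape)  # torch.Size([2, 1, 1, 5])
```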
Note that the code above is a heavily simplified piece of a Transformer; a complete model contains full encoder and decoder stacks along with many more details and optimizations. To implement a complete Transformer, you would need to define the stacking of encoder and decoder layers and add components such as positional encoding.
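As a sketch of what that closing note describes, the snippet below adds a sinusoidal positional encoding and stacks several TransformerBlock layers into an encoder. It reuses the classes defined above; the names PositionalEncoding and Encoder, the vocabulary size, and the hyperparameter values are illustrative assumptions rather than part of the original article.

```python
import math

class PositionalEncoding(nn.Module):
    # Fixed sinusoidal positional encoding, as in "Attention Is All You Need"
    def __init__(self, embed_size, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, embed_size)

    def forward(self, x):
        # x: (N, seq_len, embed_size); add the encoding for the first seq_len positions
        return x + self.pe[:, : x.shape[1], :]


class Encoder(nn.Module):
    # Token embedding + positional encoding + a stack of TransformerBlock layers
    def __init__(self, vocab_size, embed_size, num_layers, heads, dropout, forward_expansion):
        super(Encoder, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_encoding = PositionalEncoding(embed_size)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion)
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        out = self.dropout(self.position_encoding(self.word_embedding(x)))
        for layer in self.layers:
            # In the encoder, values, keys, and queries are all the same tensor
            out = layer(out, out, out, mask)
        return out


# Illustrative usage (vocab_size and token values are assumptions)
encoder = Encoder(vocab_size=10000, embed_size=256, num_layers=6, heads=8,
                  dropout=0.1, forward_expansion=4)
tokens = torch.randint(0, 10000, (64, 50))  # (batch_size, seq_length)
encoded = encoder(tokens, mask=None)
print(encoded.shape)  # torch.Size([64, 50, 256])
```

A decoder stack would follow the same pattern but would add masked self-attention over the target sequence and cross-attention to the encoder output.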