Write a Transformer
Posted: 2024-01-11 13:22:29 · Views: 36
Below is a simple example implementation of a Transformer model in PyTorch. It keeps only the core pieces (multi-head self-attention and a position-wise feed-forward block) and, for brevity, leaves out positional encodings, residual connections, and layer normalization:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Transformer(nn.Module):
    """A minimal Transformer encoder: embedding -> N x (attention, feed-forward) -> output.

    For simplicity it omits positional encodings, residual connections, and
    layer normalization, all of which a full Transformer would include.
    """

    def __init__(self, input_dim, hidden_dim, num_heads, num_layers):
        super().__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        # Each layer needs its own parameters, so build num_layers separate
        # sublayers rather than reusing a single attention/feed-forward pair.
        self.attentions = nn.ModuleList(
            MultiHeadAttention(hidden_dim, num_heads) for _ in range(num_layers)
        )
        self.feed_forwards = nn.ModuleList(
            FeedForward(hidden_dim) for _ in range(num_layers)
        )
        self.output = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        x = self.embedding(x)
        for attention, feed_forward in zip(self.attentions, self.feed_forwards):
            x = attention(x)
            x = feed_forward(x)
        return self.output(x)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a (batch, seq_len, hidden_dim) input."""

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        # Project, then split into heads: (batch, num_heads, seq_len, head_dim)
        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention; math.sqrt (not torch.sqrt) because
        # head_dim is a plain Python int, not a tensor.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)
        # Merge the heads back into (batch, seq_len, hidden_dim)
        x = torch.matmul(attention_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)
        return self.fc(x)


class FeedForward(nn.Module):
    """Position-wise feed-forward network applied to every token independently."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


# Create a Transformer model instance
input_dim = 512
hidden_dim = 256
num_heads = 8
num_layers = 4
transformer = Transformer(input_dim, hidden_dim, num_heads, num_layers)

# Input data: batch_size=10, sequence_length=20
x = torch.randn(10, 20, input_dim)

# Forward pass
output = transformer(x)
print(output.shape)  # torch.Size([10, 20, 512])
```
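The simplified model above treats its input as an unordered set: nothing tells it where each token sits in the sequence. A full Transformer fixes this by adding positional encodings to the embeddings. Here is a minimal sketch of the standard sinusoidal encoding from the original Transformer paper (the class name and the `max_len=512` cap are illustrative assumptions, and `dim` is assumed even):

```python
import math

import torch
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    """Adds fixed sine/cosine position signals to a (batch, seq_len, dim) input.

    Illustrative sketch; the original example above omits positional encoding.
    """

    def __init__(self, dim, max_len=512):  # max_len is an assumed cap
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe = torch.zeros(max_len, dim)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims get sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims get cosine
        self.register_buffer("pe", pe)  # not a parameter; moves with .to(device)

    def forward(self, x):
        # Broadcast-add the first seq_len position vectors to every batch item
        return x + self.pe[: x.size(1)]
```

In the example above, this would be applied right after `self.embedding`, before the first attention layer.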
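Also note that PyTorch already ships these building blocks: `nn.TransformerEncoderLayer` bundles multi-head attention, the feed-forward block, residual connections, and layer normalization, and `nn.TransformerEncoder` stacks them. A rough equivalent of the model above as a sketch (the hyperparameters mirror the example; `dim_feedforward=1024` is an arbitrary choice):

```python
import torch
import torch.nn as nn

# Built-in encoder blocks; unlike the hand-rolled version above, these
# include residual connections and LayerNorm by default.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=256,           # model width (hidden_dim above)
    nhead=8,               # number of attention heads
    dim_feedforward=1024,  # assumed width of the feed-forward block
    batch_first=True,      # accept (batch, seq_len, d_model) inputs
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

x = torch.randn(10, 20, 256)  # batch_size=10, sequence_length=20
out = encoder(x)
print(out.shape)  # torch.Size([10, 20, 256])
```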