多头注意力机制pytorch代码
时间: 2023-12-22 21:29:25 浏览: 139
多头注意力机制是一种在自注意力机制的基础上进行扩展的模型,它能够同时关注输入序列的不同部分并提取更丰富的特征表示。在PyTorch中,可以使用`torch.nn.MultiheadAttention`类来实现多头注意力机制。
下面是一个使用多头注意力机制的PyTorch代码示例:
```python
import torch
import torch.nn as nn
# 定义输入
batch_size = 2
seq_len = 3
embed_dim = 4
num_heads = 2
# 创建多头注意力机制实例
attention = nn.MultiheadAttention(embed_dim, num_heads)
# 创建输入张量
input = torch.randn(seq_len, batch_size, embed_dim)
# 进行多头注意力计算
output, _ = attention(input, input, input)
# 输出结果
print("Output shape:", output.shape)
print("Output tensor:", output)
```
在上述代码中,我们首先导入了`torch`和`torch.nn`模块。然后,我们定义了输入的批次大小(`batch_size`)、序列长度(`seq_len`)、嵌入维度(`embed_dim`)和头数(`num_heads`)。接下来,我们创建了一个`MultiheadAttention`实例,并传入嵌入维度和头数作为参数。然后,我们创建了一个输入张量`input`,其形状为`(seq_len, batch_size, embed_dim)`。最后,我们调用多头注意力机制的`forward`方法,并将输入张量作为参数传入。输出结果是一个形状为`(seq_len, batch_size, embed_dim)`的张量,表示多头注意力机制的输出。
阅读全文