persformer代码讲解
时间: 2023-07-09 15:38:18 浏览: 125
[ECCV2022口头]PerspectiveTransformeron3DLaneDetection_Python_C.zip
Performer是一种极速的自注意力机制,它通过使用低秩近似技术来减少计算量。下面是一个简单的Performer实现的代码示例:
```python
import torch
from torch import nn
class PerformerAttention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: torch.einsum('b n d, b n e -> b d e', t.reshape(b, n, h, -1), self.scale), qkv)
dots = torch.einsum('b i d, b j d -> b i j', q, k)
attn = dots.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', attn, v)
out = out.reshape(b, n, -1)
out = self.to_out(out)
out = self.dropout(out)
return out
```
在这个实现中,我们首先定义了一个PerformAttention类,它继承自nn.Module。在__init__函数中,我们定义了一些参数,包括输入张量的维度dim,头的数量heads,每个头的维度dim_head以及dropout率。然后,我们使用一个线性层self.to_qkv来将输入x转换为查询、键和值。我们还定义了一个线性层self.to_out来将输出转换回原始维度。我们对查询和键进行了缩放,然后计算了点积得分dots,并对其进行softmax操作以获得注意力分布。最后,我们将注意力分布乘以值,并使用self.to_out将输出转换回原始维度。我们还可以选择在输出上应用dropout。
这个实现是一个简单的版本,但是如果需要更高效的实现,可以使用一些低秩近似技术,如Fast Attention Via Positive Orthogonal Random Features (Lin et al., 2021)。
阅读全文