attn = attn + torch.bernoulli(m_r) * -1e12
时间: 2024-05-28 15:12:15 浏览: 17
这段代码是对注意力矩阵进行随机遮盖的操作,其中torch.bernoulli(m_r)生成的随机矩阵用于控制哪些位置被遮盖,-1e12表示遮盖后的注意力值非常小,相当于被完全忽略,从而达到遮盖的效果。具体来说,Bernoulli分布是一种二项分布,它生成的随机数只有两个取值0或1,其中0的概率为1-p,1的概率为p,而这里的m_r就是控制1和0的概率的参数。当随机矩阵中某个位置为1时,对应的注意力矩阵中的值会被置为-1e12,从而实现遮盖的效果。
相关问题
def flops(self, N, T, temporal=False): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) # flops += N * self.dim * 3 * self.dim if temporal: flops += self.dim * N * T * T * 2 else: # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N * T # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) * T # x = self.proj(x) # flops += N * self.dim * self.dim return flops
这段代码计算了一个窗口(window)的浮点操作(flops)。其中,输入的token长度为N,每个token的向量维度为self.dim,有self.num_heads个注意力头(attention head),每个注意力头的向量维度为self.dim // self.num_heads。如果参数temporal为True,则还需要计算时间维度T的影响。具体来说,计算过程分为两个部分:
1. 注意力机制(attention mechanism)部分,包括计算query、key、value向量(self.qkv(x)),计算注意力矩阵(q @ k.transpose(-2, -1)),以及计算加权后的value向量(attn @ v)。这部分的浮点操作数量为:
(self.num_heads * N * (self.dim // self.num_heads) * N * T) + (self.num_heads * N * N * (self.dim // self.num_heads) * T)
2. 线性变换(linear transformation)部分,包括用self.proj将加权后的value向量映射到输出向量空间。这部分的浮点操作数量为:
(N * self.dim * self.dim)
最终,两部分的浮点操作数量相加即为一个窗口的总浮点操作数量。
def flops(self): flops = 0 H, W = self.input_resolution T = self.num_frames ## just count the FLOPs of q@k and attn@v # norm1 # flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size if self.t_attn: flops += nW * self.attn.flops(self.window_size * self.window_size, T, temporal=True) flops += nW * self.attn.flops(self.window_size * self.window_size, T, temporal=False) # mlp # flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 # flops += self.dim * H * W return flops
这段代码实现了一个计算模型 FLOPs 的方法。FLOPs 是衡量模型计算量的指标,代表模型所需的浮点运算次数。在这个方法中,首先初始化 FLOPs 为0,然后根据输入分辨率、帧数、窗口大小以及是否有时间注意力,计算出注意力机制的 FLOPs 和 MLP 的 FLOPs,并返回它们的总和。此外,注释掉的代码是计算归一化操作的 FLOPs,但是在这个方法中并没有被计算。
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)