def flops(self, N, T, temporal=False): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) # flops += N * self.dim * 3 * self.dim if temporal: flops += self.dim * N * T * T * 2 else: # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N * T # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) * T # x = self.proj(x) # flops += N * self.dim * self.dim return flops
时间: 2023-06-24 07:09:22 浏览: 82
Cal-FLOPs-for-PLM:在NLP中计算预训练模型的FLOP
这段代码计算了一个窗口内的FLOPS(浮点操作数),其中N为令牌长度,T为时间步数。如果temporal为True,则计算的是temporal self-attention的FLOPS,否则计算的是普通的self-attention的FLOPS。具体来说,计算FLOPS的过程包括以下几个步骤:
1. 计算qkv,其中q、k、v均为大小为N×dim的矩阵,dim为模型的隐藏层维度。这一部分的FLOPS为N×dim×3×dim。
2. 如果是temporal self-attention,计算attn,其中attn为大小为N×T×T的矩阵,表示每个时间步之间的attention。这一部分的FLOPS为2×dim×N×T×T。
3. 如果是普通的self-attention,计算attn,其中attn为大小为num_heads×N×T×T的矩阵,表示每个头在每个时间步的attention。这一部分的FLOPS为num_heads×N×(dim/num_heads)×N×T。
4. 计算x,其中x为大小为num_heads×N×T×(dim/num_heads)的矩阵,表示每个头在每个时间步的输出。这一部分的FLOPS为num_heads×N×N×(dim/num_heads)×T。
5. 计算proj,其中proj为大小为N×dim的矩阵,表示self-attention的输出。这一部分的FLOPS为N×dim×dim。
最终,将所有步骤的FLOPS相加,即得到一个窗口内的总FLOPS。
阅读全文