没有合适的资源?快使用搜索试试~ 我知道了~
视觉transformers中的令牌池化用于图像分类
Dmitrii Marin†, Jen-Hao Rick Chang⋆, Anurag Ranjan⋆, Anish Prabhu⋆Mohammad Rastegari⋆, Oncel Tuzel⋆120视觉transformers中的令牌池化用于图像分类0† Waterloo大学,� Apple0jenhao chang@apple.com0摘要0池化通常用于改善卷积网络的计算精度权衡。通过在图像网格上聚合相邻特征值,池化层可以降低特征图的维度而保持准确性。然而,在标准视觉transformers中,令牌是单独处理的,不一定位于规则的网格上。因此,利用为图像网格设计的池化方法(例如平均池化)可能对transformers来说并不是最优的,正如我们的实验所示。在本文中,我们提出了TokenPooling来对视觉transformers中的令牌集进行降采样。我们采用了一种新的视角——不再假设令牌形成一个规则的网格,而是将它们视为隐含连续信号的离散(和不规则)样本。给定目标令牌数量,TokenPooling找到最能近似底层连续信号的令牌集。我们对标准transformer架构(ViT/DeiT)以及使用ImageNet-1k进行的图像分类问题对所提出的方法进行了严格评估。我们的实验结果表明,TokenPooling显著改善了计算精度权衡,而无需对架构进行任何进一步的修改。TokenPooling使得DeiT-Ti在使用42%更少计算的情况下达到相同的top-1准确率。01. 引言0视觉transformers [10, 14, 21, 43,55]已经证明了与卷积神经网络(CNN)相当的强大结果。然而,高计算成本限制了它们在资源受限、实时或低功耗场景中的使用。在本文中,我们旨在通过对令牌进行降采样,即在transformer层之间删除冗余的令牌,来提高标准视觉transformers在图像分类中的计算效率。降采样是CNN中用于提高计算效率的主要技术之一。在CNN中,特征被明确设计为位于规则的xy网格上,这是由于重复使用具有局部感受野的卷积操作。0(a) 准确性 vs. 计算0(b)通过令牌表示的聚类分析进行令牌池化。0图1:我们提出了一种新颖的令牌降采样方法TokenPooling,适用于像ViT [10]和DeiT[43]这样的标准视觉transformers。 (a)在DeiT中插入TokenPooling层显著改善了在ImageNet-1k上的准确性和计算之间的权衡。 (b) TokenPooling将令牌视为连续信号的离散样本。这激发我们使用聚类分析来自动聚合相邻令牌信息。我们展示了DeiT-S的第6层的输入图像和令牌聚类。0字段。xy-grid结构允许降采样技术(如最大/平均池化)聚合网格上的相邻特征值并减少网格维度。然而,在标准transformers中并没有保持xy-grid结构。在transformers中,令牌的xy位置仅通过位置编码在第一层(或几层)中作为特征值的提示。令牌通过多层感知机(MLP)分别进行处理。130(a) 令牌池化 (b) Transformer块0图2:(a)我们在每个transformer块之后插入TokenPooling层,而不修改任何架构。这使得我们能够研究仅由降采样层引起的效果。(b)显示了一个transformer块。0头部注意力,并且每个头部都由多层感知机(MLP)单独处理。换句话说,在像ViT/DeiT [10,43]这样的标准transformers中,除了第一层以图像块作为输入之外,令牌不位于规则的网格上。我们需要一种专门的降采样技术。我们提出了一种名为TokenPooling的降采样操作符,用于标准transformers,旨在提高图像分类问题的计算效率。受非均匀采样和图像压缩的启发[2, 23, 30,44],我们将令牌视为连续信号的离散和不规则样本,并将令牌降采样形式化为一个最小化从降采样的令牌中重建底层连续信号的重构误差的优化问题。令人惊讶的是,我们证明了这个优化问题的解可以通过聚类分析(如K-Means和K-Medoids)轻松找到(见图1b)。我们对所提出的方法和各种先前的降采样技术进行了彻底研究,包括随机/重要性采样令牌、基于分数的令牌修剪[13,34]以及广泛应用于最近的视觉transformers [21, 25,49]中的卷积带步长降采样方法,该方法推广了平均池化。我们专注于研究降采样层的唯一效果——我们直接在每个transformer块之后插入降采样层,而不优化任何其他架构,例如特征维度和层数,如图2a所示。我们的结果表明,Token Pooling优于基线降采样技术(Figure4),并显著改善了像ViT这样的标准视觉transformers的计算精度权衡(Figure 1a)。0贡献。本文做出了以下贡献:01 DeiT与ViT具有相同的架构,但通过改进的训练方法提高了数据效率。因此,我们在DeiT上进行实验。0•我们通过比较计算-精度权衡来对视觉transformers的先前下采样技术进行了广泛研究。0•我们分析了视觉transformer组件的计算成本和先前基于分数的下采样方法的局限性。我们展示了softmax注意力层对标记进行低通滤波,并因此构造了冗余标记。0• 我们提出了一种新颖的标记下采样技术,TokenPooling,在ImageNet-1k上在计算-精度权衡方面取得了显著的改进。02. 相关工作0在本节中,我们介绍视觉transformers,并回顾改进transformers效率的现有方法,包括现有的标记下采样方法。02.1. 视觉transformers0视觉transformers [10, 14, 21, 21, 25, 43]利用了最初由Vaswani等人[45]设计用于自然语言处理(NLP)的transformer架构,并进一步由Radford等人[31]和Devlin等人[9]推广。标准的视觉transformer是由L个transformer块组成的,它接受一组输入标记并返回另一组标记。第一个transformer块的输入标记是表示图像块的特征。在标准的transformer架构[45]中,标记的数量在整个网络中保持不变。为了进行分类,插入一个单独的分类标记来估计各个类别的概率。令深度为l的N个标记的集合为Fl = {fl0, ..., flN},其中fli ∈RM是第i个标记的特征。深度为l的transformer块φ通过多头自注意力(MSA)层和逐点多层感知机(MLP)对Fl进行处理,如图2b所示。令矩阵F ∈ RN × M0将标记Fl的行连接为F l 。然后2,0φ ( F ) = MLP(MSA( F )) ,使得(1)0MSA(F) = [O1, O2, ..., OH]WO,(2)0其中 H 是头的数量,矩阵 W O ∈ R M × M是块的可学习参数,[ , ] 是按列连接,O h ∈ R N × d 是第h 个注意力头的输出,其中 d = M/H:0A h = softmax(Q h K � h √0d ) ∈ R N × N 。 (3)0Keys K h ,queries Q h 和values V h是输入标记的线性投影(QKV投影):0Q h = FW Q h ,K h = FW K h ,V h = FW V h ,(4)02 为了紧凑性,我们省略了层归一化和跳跃连接,参见[45]。140其中 W Q h ∈ R M × d , W K h ∈ R M × d , W V h∈ R M × d是可学习的线性变换。注意,标记特征维度在整个网络中保持不变。此外,输入和输出标记的数量相同,即 |F l +1 | =|F l | = N。通过在transformer块之间插入标记下采样层,我们减少了标记的数量,从而降低了计算成本。由于下采样不可避免地会丢失信息,问题是:我们如何下采样/选择标记,以使模型的性能不会显著降低?02.2. 高效的transformers0与许多机器学习模型类似,通过元参数搜索[15,39]、自动神经架构搜索[11, 38,50]、操纵特征图的输入大小和分辨率[15, 27,54]、剪枝[19]、量化[16]和稀疏化[12]等,可以提高transformers的效率。例如,Dosovitskiy等人[10]和Touvron等人[43]通过改变输入分辨率、头的数量H和特征维度M获得了一系列ViT和DeiT模型。每个模型的计算需求和准确性都不同。接下来,我们将回顾改进transformers效率的技术。02.2.1 高效的自注意力0softmax-attention层(3)的时间复杂度与令牌数量的平方成正比,即O(N^2)。在许多NLP应用中,每个令牌代表一个单词或一个字符,N可能很大,使得注意力成为计算瓶颈[7,32]。虽然许多工作改进了注意力层的时间复杂度,如我们将在第3.1节中看到的,但它们在大多数视觉Transformer中并不是瓶颈。注意力层的时间复杂度可以通过限制注意力视野从而在A_h(3)中引入稀疏性来降低。这可以通过使用图像/文本领域中的令牌之间的空间关系[3, 4, 26, 29, 33,53]或基于令牌值使用局部敏感哈希、排序、压缩、聚类等方法[18, 20, 40, 41,46–48]来实现。先前的工作还提出了具有较低时间复杂度的注意力机制,例如O(N)或O(N log N)[5, 17, 28,40]。请注意,这些方法的目标是降低注意力层的时间复杂度-令牌数量在Transformer块之间保持不变。相反,我们的方法在计算注意力后减少了令牌的数量。因此,我们可以利用这些方法进一步提高Transformer的整体效率。最近,Wu等人[51]提出了一种新的基于注意力的层,该层学习少量的查询向量以从输入特征图中提取信息。Roy等人[35]0聚类查询和键以稀疏化注意力矩阵并加快注意力。类似地,Wu等人[52]提出了中心化Transformer,它将自定义的软K-means展开为一个特殊的可学习的注意力层。其效果类似于PoWER-BERT[13](使用静态策略),在第5节中将进行广泛比较。相比之下,我们的方法没有可学习的参数,并且通过将令牌视为底层连续信号的不规则样本,直接最小化由于令牌下采样而引起的重构误差。02.2.2 现有的令牌下采样方法0网格下采样。网格下采样技术假设令令牌位于一个规则的网格上,通常通过使用它们在图像上的初始位置来排列令牌来实现。规则的网格结构允许使用典型的下采样方法,如最大/平均池化或均匀子采样。例如,Liu等人[21],Heo等人[14]和Wang等人[49]使用步长卷积(即卷积后进行子采样)来对令牌形成的特征图进行下采样;同样,Dai等人[6]使用平均池化。请注意,正如我们在第1节中讨论的那样,标准的Transformer块不保留图像网格结构,并将令牌视为无序集合。为了鼓励令牌形成网格结构,Liu等人[21]修改了Transformer块,并限制了MSA对输入图像网格上的相邻令牌的注意范围。这种架构中的归纳偏差-以计算效率为代价的注意范围交换-使得他们的方法在各种应用中取得了最先进的结果。在本文中,我们重点研究令牌下采样层的唯一效果-保留标准的Transformer块架构,而不像Liu等人[21]那样修改Transformer块。正如我们的分析将展示的那样,在没有对Transformer块进行任何修改的情况下,通过步长卷积(或等效的平均池化)对令牌进行下采样实际上比一个简单的随机选择/删除令牌的算法表现更差,从而验证了我们关于网格结构的假设。0基于分数的令牌下采样。在NLP领域,Goyal等人[13]引入了基于“显著性分数”的PoWER-BERT,该分数定义为所有其他令牌对一个令牌的总注意力。具体而言,第l个Transformer块中所有令牌的显著性分数,s_l ∈R^N,通过以下方式计算0sl =0h =1 Alh�1,(5)0其中Alh是头h的注意权重,定义如式(3)。它们只传递具有sl中最高分数的Kl个标记到下一个变换器块。=1=150下一个变换器块。与我们的设置类似,修剪是在所有块上执行的。PoWER-BERT使用三阶段的训练过程。首先,给定一个基本架构,他们在不修剪的情况下预训练模型。在第二阶段,每个变换器块后插入一个软选择层,并对模型进行少量的微调。一旦学习到,从软选择层计算出每层要保留的标记数Kl。最后,再次对模型进行微调,剪掉不太重要的标记。详细信息请参阅[13]。最近,Rao等人提出了Dynamic-ViT,也使用分数来修剪标记。与PoWER-BERT不同,Rao等人使用具有学习参数的专用子网络来计算重要性分数。该方法需要知识蒸馏、Gumbel-Softmax和直通估计器,以及DeiT的训练和架构。我们将在第3.2节中分析基于分数的方法,并在第5节中与它们进行比较。03. 分析0本节回答了三个问题。首先,我们确定了视觉变换器的计算瓶颈。其次,我们讨论了基于分数的下采样的局限性。最后,我们分析了softmax-attention如何影响变换器中标记的冗余性。这一发现对于设计下采样算法非常重要。03.1. 视觉变换器的计算分析0我们分析了视觉变换器(ViT/DeiT)的时间复杂性和计算成本(以flops为单位)。我们将计算分为四个类别:softmax-attention(3)、QKV投影(4)、O投影(2)和MLP(1)。如表1所示,在所有这些视觉变换器中,主要的计算瓶颈是全连接层,它们占总计算量的80%以上。相比之下,softmax-attention只占不到15%。请注意,我们明确将多头注意力分解为softmax-attention、QKV和O投影,因为它们具有不同的时间复杂性。这种分解揭示了QKV和O投影在多头自注意力计算中占据了大部分计算量。全连接层的时间复杂度为O(LNM^2)。通过对标记进行下采样(即减少N),我们可以提高时间复杂度,而不会显著降低它们的容量(即特征维度M和层数L)。03.2. 基于分数的下采样的局限性0现有的基于分数的标记下采样方法,如PoWER-BERT和Dynamic-ViT,利用评分函数确定要保留(或修剪)的标记。它们保留具有前K个最高分数的标记,并丢弃其余的标记。由于0(a)基于分数的(b)提出的0图3:基于分数的下采样方法[13,34]与提出的方法的比较。在图中,x轴表示标记值(在一个维度上),y轴表示它们的分数。假设要选择四个标记。(a)基于分数的方法选择具有较高分数的标记。由于评分函数是连续的,左侧叶子中的所有标记都将被选择,导致右侧叶子中的冗余和信息损失。(b)提出的方法首先形成四个簇以逼近标记集合,然后选择簇中心。因此,输出的标记比基于分数的方法更准确地表示原始标记集合。0由于这些评分函数在有限的Lipschitz常数下是连续的,特征空间中接近的标记将被分配相似的分数。因此,类似的标记很可能全部保留或丢弃,如图3a所示。正如我们的实验所示,这种冗余(在保留的标记中)和严重的信息损失(在修剪的标记中)会恶化基于分数的下采样方法的计算-准确性权衡。03.3. Attention as a low-pass �lter(3.3.注意力作为低通滤波器)0Given a query vector q , a set of key vectors K = { k 1 , . . . , k N }, the corresponding value vectors V = { v 1 , . . . , v N } and ascalar α > 0 , softmax-attention com- putes the output via(给定查询向量q,一组键向量K = {k1, ..., kN},相应的值向量V = {v1,..., vN}和标量α > 0,通过softmax-attention计算输出)0o ( q ) = 1(2.o(q) = 0i =1 exp( α q ∙ k i ) v i , (6)(∑(e^(αq∙ki))vi,(6))0where z ( q ) = � N i =1 exp( α q ∙ k i ) . Note that we write o ( q ) toindicate that the output vector o is a function of the query q . If thequery vector and all key vectors are normalized to have a �xed ℓ 2norm, we can rewrite (6) as (其中z(q) =∑(e^(αq∙ki)),注意我们写o(q)是为了表示输出向量o是查询q的函数。如果查询向量和所有键向量都被归一化为固定的ℓ2范数,我们可以将(6)重写为)0o ( q ) = 1(o(q) = 1)0z ′ ( q )(2. z'(q))0i =1 exp �α (2.02 ∥ q − k i ∥ 2 � v i(2.2∥q-ki∥^2vi)0z ′ ( q )0� exp � (e^(-α))02 ∥ q − k ∥ 2 ��N � (2.2∥q-k0i =1 δ ( k − k i ) v i(∑δ(k-ki)vi)0�(2. *0z ′ ( q ) G � q 1 (2. z'(q) =G(q; 1))0� � S ( q ; K , V ) , (7) (2. * S(q; K, V),(7))0w (q 0i exp � − α e^(-α))02 ∥ q − k i ∥ 2 � = G � q ; 1(2.2∥q-ki∥^2 = G(q; 1))0α � � S ( q ; K , 1) is thenor- (2.α * S(q; K, 1)是)Pruning tokens inevitably loses information. In this sec-tion, we formulate a new token downsampling principle en-abling strategical tokens selection that preserves the mostinformation. Based on this principle, we formulate and dis-cuss several Token Pooling algorithms.Given a set of output tokens F = {f1, . . . , fN} of atransformer block, our goal is to find a smaller set of to-ℓ(F, ˆF) =�fi∈F∥fi − ˆu(fi; ˆF)∥2.(8)ℓ(F, ˆF) =�fi∈Fminˆfj∈ ˆF∥fi − ˆfj∥2,(9)160Layer Complexity Computation ( 10 9 Flops) (层复杂度计算(10^9次浮点运算))0ViT-B/384 ( N =577 ) ViT-B ( N =197 ) DeiT-S ( N =197 ) DeiT-Ti (N =197 ) (ViT-B/384 (N=577) ViT-B (N=197) DeiT-S (N=197)DeiT-Ti (N=197))0softmax-attn. O ( LN 2 M ) 6.18 0.72 0.36 0.18 QKV proj. O ( LNM 2 ) 12.25 4.18 1.050.26 O proj. O ( LNM 2 ) 4.08 1.39 0.35 0.09 MLP O ( LNM 2 ) 32.67 11.15 2.79 0.70(softmax-attention: O(LN^2M) 6.18 0.72 0.36 0.18, QKV投影: O(LNM^2) 12.25 4.181.05 0.26, O投影: O(LNM^2) 4.08 1.39 0.35 0.09, MLP: O(LNM^2) 32.67 11.15 2.790Total O ( LNM ( M + N )) 55.5 17.6 4.6 1.3 (总计O(LNM(M+N)) 55.5 17.6 4.6 1.3)0Table 1: Time complexity and computation breakdown of ViT [10] and DeiT [43]. L is the number of transformer blocks, N is the number of input tokens, andM is the feature dimensionality. All models take input images of size 224 × 224 except ViT-B/384, which uses 384 × 384 . The softmax-attention layersconstitute a fraction (15% or less) of the total compute, whereas fully-connected layers (MLP and projections) spend over 80%. (表1:ViT [10]和DeiT[43]的时间复杂度和计算分解。L是transformer块的数量,N是输入令牌的数量,M是特征维度。所有模型的输入图像大小为224×224,除了ViT-B/384使用384×384。softmax-attention层占总计算量的一小部分(15%或更少),而全连接层(MLP和投影)占据了80%以上。)0malization scalar function, G � q ; σ 2 � = exp � −∥ q ∥ 2 0is an isometric Gaussian kernel, and S ( q ; K , V ) = � N i =1 δ ( q − ki ) v i is a high-dimensional sparse signal, which is composed of Ndelta functions located at k i with value v i . According to (7), givenquery vectors q 1 , . . . , q N , there are two conceptual steps tocompute softmax-attention: (是等距高斯核,S(q; K, V) =∑δ(q-ki)vi是一个高维稀疏信号,由位于ki处且值为vi的N个δ函数组成。根据(7),给定查询向量q1,...,qN,计算softmax-attention有两个概念步骤:)01. �lter S ( q ; K , V ) with a Gaussian kernel to get o ( q ) , which is ahigh-dimensional continuous signal, and (1.使用高斯核对S(q; K,V)进行滤波,得到o(q),它是一个高维连续信号,并)02. sample o ( q ) at coordinates q 1 , . . . , q N to get the out- put vectorso 1 , . . . , o N . Since Gaussian �ltering is low-pass, o ( q ) is a smooth sig-nal. Therefore, the output tokens of the attention layer, i . e ., discretesamples of o ( q ) , contain similar feature values. In other words, thereexists redundant information in the out- put tokens, and we can prune thisredundancy to without losing much of the important information [37]. Notethat our analysis is based on the normalized query and key vectors, whichcan be achieved by inserting a nor- malizing layer before thesoftmax-attention layer without signi�cantly affecting the performance of atransformer, as demonstrated by Kitaev et al . [18]. It has also been em-pirically observed by Goyal et al . [13] and Rao et al . [34] that even withoutthe normalization, transformers produce tokens with similar values. Todemonstrate this, in all our experiments, we use standard multi-headattention without normalizing keys and queries. We conduct an ablationstudy with normalized keys and queries in Appendix F.(2.在坐标q1,...,qN处对o(q)进行采样,得到输出向量o1,...,oN。由于高斯滤波是低通的,o(q)是一个平滑的信号。因此,注意力层的输出令牌,即o(q)的离散样本,包含相似的特征值。换句话说,输出令牌中存在冗余信息,我们可以通过修剪这些冗余信息而不会丢失太多重要信息[37]。请注意,我们的分析基于归一化的查询和键向量,可以通过在softmax-attention层之前插入一个归一化层来实现,而不会显著影响transformer的性能,正如Kitaev等人[18]所证明的那样。Goyal等人[13]和Rao等人[34]也经验性地观察到,即使没有归一化,transformer生成的令牌也具有相似的值。为了证明这一点,在我们的所有实验中,我们使用标准的多头注意力而不对键和查询进行归一化。我们在附录F中进行了带有归一化键和查询的消融研究。)04. Token pooling(4.令牌池化)0kens ˆ F = { ˆ f 1 , . . . , ˆ f K }使得F由于下采样而产生的重构误差最小。正如我们在第3.3节中所示,我们将F视为连续信号u(f)的离散样本。类似地,ˆF表示另一个信号ˆu(f)。由于K
下载后可阅读完整内容,剩余1页未读,立即下载
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功