vision transformer的class token的作用是什么
时间: 2023-11-24 12:07:29 浏览: 443
Vision Transformer (ViT) 是一种使用自注意力机制(Self-Attention Mechanism)的深度神经网络模型,用于图像分类任务。在 ViT 模型中,Class Token 是一个额外的向量,它被添加到图像的嵌入表示中,然后传递给 Transformer 中的最后一个注意力层。Class Token 的作用是为模型提供一个全局信息的汇总,它捕获了整个图像的语义信息,这有助于模型更好地理解整个图像,并更准确地分类图像。因此,Class Token 是在 ViT 模型中非常重要的组成部分之一。
相关问题
transformer中token
### Transformer 模型中的 Token 处理方式及作用
#### 输入序列构建
在 Transformer 模型中,Token 是基本的输入单元。对于文本数据而言,每个单词或子词会被转换成一个唯一的 ID 或嵌入向量形式的 Token[^1]。
#### 特殊 Tokens 的引入
为了特定目的,一些特殊的 Tokens 也会被添加到输入序列当中。例如,在 BERT (Bidirectional Encoder Representations from Transformers) 中使用的 `[CLS]` 和 `[SEP]` Tokens。其中 `[CLS]` 表示分类任务所需的特殊标记;而 `[SEP]` 则用于区分不同的句子片段[^3]。
#### Class Token 在 Vision Transformer 中的应用
特别地,在视觉任务上的 ViT (Vision Transformer) 架构里,除了常规来自图片切片得到的 patch tokens 外还会额外附加一个 class token。这个 class token 同样参与整个 Transformer 编码过程,并最终仅以其对应的输出部分来进行类别预测工作[^4]。
#### 计算资源影响
当在一个已经存在的输入序列基础上再增添新的 Token 时——无论是普通的还是像上述提到过的那些具有专门用途的——都会使得整体计算负担有所上升。这是因为更多的 Tokens 导致了更大的矩阵运算规模以及更复杂的多头自注意机制运作需求,从而增加了每一步迭代所需的时间开销和硬件资源占用情况。
```python
import torch
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
text = "Using a Transformer network is simple."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)
print(output.last_hidden_state.shape)
```
VIsion Transformer
### Vision Transformer 架构详解
Vision Transformer (ViT) 是一种基于纯变换器架构的视觉模型,最初由 Google 团队于 2020 年提出。该模型旨在处理图像数据并执行分类任务,通过将图像分割成多个小块来模仿自然语言处理中的词元化操作[^2]。
#### 图像分片与线性嵌入
输入图像被均匀划分为固定大小的小图块(patches),这些图块随后会被展平为一维向量,并经过线性映射转换为具有相同维度的特征向量。为了保留位置信息,在此阶段还会加入可训练的位置编码[^1]。
```python
class PatchEmbedding(nn.Module):
def __init__(self, img_size=224, patch_size=16, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x).flatten(2).transpose(1, 2)
return x
```
#### 变换器编码层堆叠
得到的一系列带有位置信息的特征向量作为输入传递给一系列相同的变换器编码单元组成的网络。每个编码单元内部包含了多头自注意机制以及前馈神经网络两大部分,二者之间采用残差连接和标准化技术以促进梯度传播[^3]。
```python
class Block(nn.Module):
"""Transformer block."""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop_path_rate=0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
)
# ...其余代码省略...
```
#### 类标记与全局平均池化
在序列最前端添加一个特殊的类别令牌([CLS] token),用于收集整个图片的信息摘要。最终输出时仅需对该令牌对应的隐藏状态做进一步处理即可获得预测结果。此外,还可以选择使用全局平均池化代替类标记来进行表征聚合。
```python
class VisionTransformer(nn.Module):
def __init__(self, ..., classifier='token'):
...
if classifier == 'token':
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
elif classifier == 'gap':
self.has_cls_token = False
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
if self.classifier == "token":
x = x[:, 0]
else:
x = x.mean(dim=1)
return x
```
### 应用场景
除了经典的图像分类外,ViT 还可以应用于其他多种计算机视觉任务中,比如目标检测、语义分割等。由于其强大的表达能力和灵活性,使得 ViT 成为了当前热门的研究方向之一[^4]。
阅读全文
相关推荐
















