1.ViT概述 在上一篇文章中主要讲了 Transformer 的基本原理 ,尤其是在自然语言处理(NLP)任务中的应用,包括编码器和解码器的主要功能和注意力机制的具体实现。但这些内容大多基于 NLP 领域的示例,本篇我们看看在计算机视觉(CV)领域,Transformer 在图像任务中的使用方式。
1.在视觉领域的发展背景
Transformer 模型自 2017 年在 NLP 领域大获成功之后,于 2020 年开始进入视觉领域。第一个具有里程碑意义的工作是 Google 提出的 Vision Transformer (ViT) ,它首次在图像分类任务中展现出优异表现,随后 DETR (由 Facebook AI 提出)也将 Transformer 成功应用于目标检测任务。 这标志着 Transformer 模型正式进军 CV,并在分类、检测、分割等多个任务中取得突破性进展。随着模型结构的不断优化,特别是对计算量的压缩与架构的轻量化,Transformer 已逐步从“只能搞研究”的状态,走向实际应用的阶段。
Transformer 模型在图像任务中的核心挑战是:如何将图像转换为序列化的输入形式 ,以适配 Transformer 的原始结构。
1.图像如何转为序列? Transformer 的输入是一个向量序列,在 NLP 中通过词嵌入(embedding)将词语转化为向量,而在视觉中,我们可以输入二维图像矩阵,而这需要一种方式将其转换为向量序列。通常的做法有以下两种:
直接切分图像块(Patch Embedding) :将图像均匀划分成小块,如 $16\times16$ 大小的 patch,每个 patch 被展平为一个向量,这就是 ViT 的实现逻辑。
使用卷积提取 patch 特征 :对每个 patch 通过多个卷积核提取特征(如使用 512 个卷积核),使每个图像块最终映射为一个高维向量(如 512 维),再组成序列。
这一步的核心目标是:获得一个由向量组成的序列,每个向量代表图像中一个局部区域的特征 。
与卷积神经网络(CNN)相比,Transformer 的显著优势是全局建模能力强 。在 CNN 中,信息传播是局部的、逐层扩散的;而 Transformer 通过 self-attention 机制,在每一层中就能捕捉到任意位置之间的关系,实现了更强的全局依赖建模。 例如:
如果把网络比作公司,CNN 就像是新员工从基层逐步熟悉公司结构;而 Transformer 就像空降高管,上来就能和公司所有部门对接,迅速掌握全局。
Transformer 最终输出的是每个 patch 的全局上下文特征。根据任务的不同,处理方式也不同:
分类任务 :通常使用一个特殊的 [CLS] token 来聚合所有 patch 信息,或对所有位置的向量进行池化,得到全图的表征向量,接分类头输出预测结果。
目标检测、图像分割 :输出的序列会进一步送入解码模块或其他结构中,用于生成 bounding box、mask 或语义标签。
2.ViT整体实现架构 这张图是 Vision Transformer (ViT) 的整体结构示意图,左边是 ViT 的整体流程,右边是每一个 Transformer Encoder 的内部结构。
1.ViT 的整体流程 1.输入图像切分成Patch 原始图像会被均匀地切成多个固定大小的图像块(patches),比如 $16 \times 16$ 像素一块。图中最底下显示了将整幅图像切成了 9 个 patch。
2.每个 patch 展平 + 线性投影 每个 patch 被展平成一个向量(Flatten),比如形状 [3, 16, 16](RGB通道)展平成 3×16×16 = 768 维。然后通过一个线性变换(Linear Projection)投影成固定维度,比如 D = 768,作为 token 向量。
3.加上位置编码Position Embedding 因为 Transformer 本身不理解顺序,所以每个 patch 都要加上对应的 位置编码 (可学习的向量),告诉模型“这是第几个 patch”。 同时,图中还有一个额外的 * patch —— 可学习的 [class] token ,用来代表整张图像的全局信息(类似 BERT 的 [CLS] token),最终用于分类。
所有 patch 向量 + [class] token 一起作为 token 序列,送入 Transformer Encoder。 这个 Encoder 会做自注意力计算,进行信息融合。
5.分类头(MLP Head) 只取最前面的 [class] token 作为全图表达,送入一个简单的 MLP(多层感知机) 来输出分类结果(鸟、球、车等)。
每个 Encoder Block 都是标准的 Transformer 结构,重复堆叠 $L$ 层。
1.子结构
Multi-Head Self-Attention :让每个 patch token 能看到其它所有 patch 的信息。通过注意力机制动态决定“我关注谁”。
Feed Forward(MLP) :每个 token 单独通过一个两层的前馈网络,增加非线性表达能力。
2.其它 LayerNorm 层归一化(Norm)和 残差连接
3.数据预处理 具体实现在models/modeling.py
文件的 Embeddings 类,Embeddings 类负责把图像切成小块(patch),每个小块变成一个向量,加上位置编码 + 特殊的 [CLS] token,最后输出一个 shape 为 [B, N+1, D] 的 token 序列。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 class Embeddings (nn.Module): def __init__ (self, config, img_size, in_channels=3 ): super (Embeddings, self ).__init__() self .hybrid = None img_size = _pair(img_size) """ 1️⃣.判断是否使用 hybrid 模型(是否用 CNN 提特征) 如果用了 grid,说明模型启用了 Hybrid ViT 模式,先用一个 CNN(比如 ResNet) 处理图片,再将 CNN 的输出特征图切成 patch。否则就是标准的 ViT(直接对原始图像 切 patch)。 """ if config.patches.get("grid" ) is not None : grid_size = config.patches["grid" ] patch_size = (img_size[0 ] // 16 // grid_size[0 ], img_size[1 ] // 16 // grid_size[1 ]) n_patches = (img_size[0 ] // 16 ) * (img_size[1 ] // 16 ) self .hybrid = True else : """ 2️⃣.计算 patch 的尺寸和数量 将输入图像大小除以 patch 大小,得到总共切了多少个 patch patch 大小和数量关系着位置编码 shape 和后续 patch embedding 的 shape """ patch_size = _pair(config.patches["size" ]) n_patches = (img_size[0 ] // patch_size[0 ]) * (img_size[1 ] // patch_size[1 ]) self .hybrid = False if self .hybrid: self .hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) in_channels = self .hybrid_model.width * 16 """ 3️⃣.构建 patch embedding 层(Conv2d 模拟 patch + linear) 直接用 Conv2d 进行切 patch + flatten + linear映射三合一操作 输入是原始图像或 CNN 特征图,输出是 shape 为 [B, D, H_patch, W_patch] 的特征 """ self .patch_embeddings = Conv2d( in_channels=in_channels, out_channels=config.hidden_size, kernel_size=patch_size, stride=patch_size ) """ 4️⃣.构建位置编码与 [CLS] token 位置编码:可学习的位置向量,告诉模型每个 patch 的“位置信息” CLS token:特殊 token,代表整张图的全局语义,最终用于分类 """ self .position_embeddings = nn.Parameter(torch.zeros(1 , n_patches + 1 , config.hidden_size)) self .cls_token = nn.Parameter(torch.zeros(1 , 1 , config.hidden_size)) self .dropout = Dropout(config.transformer["dropout_rate" ]) def forward (self, x ): B = x.shape[0 ] cls_tokens = self .cls_token.expand(B, -1 , -1 ) if self .hybrid: x = self .hybrid_model(x) """ 5️⃣.前向传播:将图像编码为序列 token 图像 → patch → patch embedding → 加 CLS → 加位置信息 → Dropout """ x = self .patch_embeddings(x) x = x.flatten(2 ) x = x.transpose(-1 , -2 ) x = torch.cat((cls_tokens, x), dim=1 ) embeddings = x + self .position_embeddings embeddings = self .dropout(embeddings) return embeddings
供后续 Transformer encoder 处理的输出格式:
1 2 3 4 5 shape: [B, N+1 , D] B:batch_size N:patch数量 +1 :CLS token D:每个 token 的维度(hidden_size)
Embeddings实现逻辑有一个很好的类比:
“把图像裁成一堆小图片(patch),每块都贴上编号(位置编码),再加个代表全图的‘班长token’,然后把它们一起打包发到 transformer 班级里开会。”
具体实现在models/modeling.py
文件的 Embeddings 类的 forward 方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 def forward (self, x ): B = x.shape[0 ] """ 1️⃣.拼接所有patch embedding和position embedding 这三步将所有patch的embedding拼接成一个序列,形状为 [batch, N_patches, hidden] """ x = self .patch_embeddings(x) x = x.flatten(2 ) x = x.transpose(-1 , -2 ) """ 2️⃣.添加CLS token:创建了一个可学习的CLS token,并在序列最前面拼接 """ cls_tokens = self .cls_token.expand(B, -1 , -1 ) x = torch.cat((cls_tokens, x), dim=1 ) """ 3️⃣.加上可学习的位置编码 位置编码和patch embedding逐元素相加,保留了位置信息 """ embeddings = x + self .position_embeddings embeddings = self .dropout(embeddings) return embeddings
patch embedding
和position embedding
的拼接、CLS token
的添加都在Embeddings类的forward方法中实现,最终输出作为Transformer的输入。
这是Transformer编码器的实现,负责将输入的patch序列(含CLS token)通过多层Transformer Block进行特征提取和全局信息交互。
Encoder 类就是多个 Transformer 编码器 Block 的堆叠,并在最后加一个 LayerNorm,最后输出编码后的 token 和(可选)注意力权重。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 class Encoder (nn.Module): def __init__ (self, config, vis ): super (Encoder, self ).__init__() self .vis = vis self .layer = nn.ModuleList() self .encoder_norm = LayerNorm(config.hidden_size, eps=1e-6 ) for _ in range (config.transformer["num_layers" ]): layer = Block(config, vis) self .layer.append(copy.deepcopy(layer)) def forward (self, hidden_states ): attn_weights = [] for layer_block in self .layer: hidden_states, weights = layer_block(hidden_states) if self .vis: attn_weights.append(weights) encoded = self .encoder_norm(hidden_states) return encoded, attn_weights
可以把这个 Encoder 类比成一个“专家委员会”,由多个“专家(Block)”组成,图像每个 patch 是一个“提案”,他们在每一轮里互相交换信息、总结观点,最终得出一份“共识文档(encoded)”。
2.Block结构 这个 Block 类是 Vision Transformer 的一个标准 Transformer 编码器子层,包括:
一个 多头自注意力(Multi-head Self-Attention)模块
一个 前馈神经网络(MLP)
两个 LayerNorm
两个 残差连接(Residual Add)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 class Block (nn.Module): def __init__ (self, config, vis ): super (Block, self ).__init__() self .hidden_size = config.hidden_size self .attention_norm = LayerNorm(config.hidden_size, eps=1e-6 ) self .ffn_norm = LayerNorm(config.hidden_size, eps=1e-6 ) self .ffn = Mlp(config) self .attn = Attention(config, vis) def forward (self, x ): print (x.shape) h = x x = self .attention_norm(x) x, weights = self .attn(x) x = x + h h = x x = self .ffn_norm(x) x = self .ffn(x) x = x + h return x, weights def load_from (self, weights, n_block ): ROOT = f"Transformer/encoderblock_{n_block} " with torch.no_grad(): query_weight = np2th(weights[ROOT + "/" + ATTENTION_Q + "/kernel" ]).view(self .hidden_size, self .hidden_size).t() key_weight = ... value_weight = ... out_weight = ... query_bias = ... key_bias = ... value_bias = ... out_bias = ... self .attn.query.weight.copy_(query_weight) self .attn.key.weight.copy_(key_weight) mlp_weight_0 = np2th(weights[ROOT + "/" + FC_0 + "/kernel" ]).t() mlp_weight_1 = np2th(weights[ROOT + "/" + FC_1 + "/kernel" ]).t() mlp_bias_0 = ... mlp_bias_1 = ... self .ffn.fc1.weight.copy_(mlp_weight_0) self .ffn.fc2.weight.copy_(mlp_weight_1) self .attention_norm.weight.copy_(np2th(weights[ROOT + "/" + ATTENTION_NORM + "/scale" ])) self .attention_norm.bias.copy_(np2th(weights[ROOT + "/" + ATTENTION_NORM + "/bias" ])) self .ffn_norm.weight.copy_(...)
Block结构图:
1 2 3 4 5 6 7 8 9 10 11 12 13 输入 x (B, N, D) ↓ LayerNorm(层归一化) ↓ Residual Add(残差连接) ↓ LayerNorm(层归一化) ↓ FFN(MLP)(前馈神经网络) ↓ Residual Add(残差连接) ↓ 输出 x
这是整个ViT模型的顶层封装,负责将输入图片经过embedding、Transformer编码器、分类头等完整流程,输出最终的分类结果。主要包含
完整的 Transformer 编码器结构
线性分类头 head
从预训练权重中加载参数的load_from() 方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 class VisionTransformer (nn.Module): def __init__ (self, config, img_size=224 , num_classes=21843 , zero_head=False , vis=False ): super (VisionTransformer, self ).__init__() self .num_classes = num_classes self .zero_head = zero_head self .classifier = config.classifier self .transformer = Transformer(config, img_size, vis) self .head = Linear(config.hidden_size, num_classes) def forward (self, x, labels=None ): x, attn_weights = self .transformer(x) logits = self .head(x[:, 0 ]) print (logits.shape) if labels is not None : loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1 , self .num_classes), labels.view(-1 )) return loss else : return logits, attn_weights def load_from (self, weights ): with torch.no_grad(): if self .zero_head: nn.init.zeros_(self .head.weight) nn.init.zeros_(self .head.bias) else : self .head.weight.copy_(np2th(weights["head/kernel" ]).t()) self .head.bias.copy_(np2th(weights["head/bias" ]).t()) self .transformer.embeddings.patch_embeddings.weight. copy_(np2th(weights["embedding/kernel" ], conv=True )) self .transformer.embeddings.patch_embeddings.bias. copy_(np2th(weights["embedding/bias" ])) self .transformer.embeddings.cls_token. copy_(np2th(weights["cls" ])) self .transformer.encoder.encoder_norm.weight. copy_(np2th(weights["Transformer/encoder_norm/scale" ])) self .transformer.encoder.encoder_norm.bias. copy_(np2th(weights["Transformer/encoder_norm/bias" ])) posemb = np2th(weights["Transformer/posembed_input/pos_embedding" ]) posemb_new = self .transformer.embeddings.position_embeddings if posemb.size() == posemb_new.size(): self .transformer.embeddings.position_embeddings.copy_(posemb) else : logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) ntok_new = posemb_new.size(1 ) if self .classifier == "token" : posemb_tok, posemb_grid = posemb[:, :1 ], posemb[0 , 1 :] ntok_new -= 1 else : posemb_tok, posemb_grid = posemb[:, :0 ], posemb[0 ] gs_old = int (np.sqrt(len (posemb_grid))) gs_new = int (np.sqrt(ntok_new)) print ('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1 ) zoom = (gs_new / gs_old, gs_new / gs_old, 1 ) posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1 ) posemb_grid = posemb_grid.reshape(1 , gs_new * gs_new, -1 ) posemb = np.concatenate([posemb_tok, posemb_grid], axis=1 ) self .transformer.embeddings.position_embeddings. copy_(np2th(posemb)) for bname, block in self .transformer.encoder.named_children(): for uname, unit in block.named_children(): unit.load_from(weights, n_block=uname) if self .transformer.embeddings.hybrid:
结构图:
1 2 3 4 5 6 VisionTransformer ├── Transformer │ ├── Embeddings(Patch + Pos + CLS) │ ├── Encoder(多个 Block + LayerNorm) ├── head(分类器) └── load_from(支持不同输入尺寸权重加载 + hybrid 模块加载)
7.总结 本文主要介绍了Vision Transformer (ViT) 的主要实现,包括:
ViT概述 :Transformer从NLP扩展到CV领域,ViT是首个在图像分类中表现优异的Transformer模型。核心思路:将图像切分为patch序列,通过Transformer处理,利用self-attention的全局建模能力替代CNN的局部感受野
ViT架构 :
图像分块 :将输入图像切分为固定大小的patch
Patch Embedding :每个patch展平后通过线性投影得到向量表示
位置编码 :添加可学习的位置编码保留空间信息
[CLS] Token :添加特殊token用于分类任务
Transformer Encoder :多层自注意力机制和前馈网络堆叠
关键实现 :
Embedding层 :处理图像分块、位置编码和CLS token
Transformer Encoder :包含多头自注意力和MLP,使用LayerNorm和残差连接
分类头 :基于CLS token输出分类结果
8.备注