ViT — Transformer在视觉领域应用代码解析

1.ViT概述

在上一篇文章中主要讲了 Transformer 的基本原理,尤其是在自然语言处理(NLP)任务中的应用,包括编码器和解码器的主要功能和注意力机制的具体实现。但这些内容大多基于 NLP 领域的示例,本篇我们看看在计算机视觉(CV)领域,Transformer 在图像任务中的使用方式。

1.在视觉领域的发展背景

Transformer 模型自 2017 年在 NLP 领域大获成功之后,于 2020 年开始进入视觉领域。第一个具有里程碑意义的工作是 Google 提出的 Vision Transformer (ViT),它首次在图像分类任务中展现出优异表现,随后 DETR(由 Facebook AI 提出)也将 Transformer 成功应用于目标检测任务。
这标志着 Transformer 模型正式进军 CV,并在分类、检测、分割等多个任务中取得突破性进展。随着模型结构的不断优化,特别是对计算量的压缩与架构的轻量化,Transformer 已逐步从“只能搞研究”的状态,走向实际应用的阶段。

2.视觉任务中 Transformer 的基本思路

Transformer 模型在图像任务中的核心挑战是:如何将图像转换为序列化的输入形式,以适配 Transformer 的原始结构。

1.图像如何转为序列?

Transformer 的输入是一个向量序列,在 NLP 中通过词嵌入(embedding)将词语转化为向量,而在视觉中,我们可以输入二维图像矩阵,而这需要一种方式将其转换为向量序列。通常的做法有以下两种:

  • 直接切分图像块(Patch Embedding):将图像均匀划分成小块,如 $16\times16$ 大小的 patch,每个 patch 被展平为一个向量,这就是 ViT 的实现逻辑。
  • 使用卷积提取 patch 特征:对每个 patch 通过多个卷积核提取特征(如使用 512 个卷积核),使每个图像块最终映射为一个高维向量(如 512 维),再组成序列。

这一步的核心目标是:获得一个由向量组成的序列,每个向量代表图像中一个局部区域的特征

2.为什么使用 Transformer?

与卷积神经网络(CNN)相比,Transformer 的显著优势是全局建模能力强。在 CNN 中,信息传播是局部的、逐层扩散的;而 Transformer 通过 self-attention 机制,在每一层中就能捕捉到任意位置之间的关系,实现了更强的全局依赖建模。
例如:

如果把网络比作公司,CNN 就像是新员工从基层逐步熟悉公司结构;而 Transformer 就像空降高管,上来就能和公司所有部门对接,迅速掌握全局。

3.Transformer 输出的用途

Transformer 最终输出的是每个 patch 的全局上下文特征。根据任务的不同,处理方式也不同:

  • 分类任务:通常使用一个特殊的 [CLS] token 来聚合所有 patch 信息,或对所有位置的向量进行池化,得到全图的表征向量,接分类头输出预测结果。
  • 目标检测、图像分割:输出的序列会进一步送入解码模块或其他结构中,用于生成 bounding box、mask 或语义标签。

2.ViT整体实现架构


这张图是 Vision Transformer (ViT) 的整体结构示意图,左边是 ViT 的整体流程,右边是每一个 Transformer Encoder 的内部结构。

1.ViT 的整体流程

1.输入图像切分成Patch

原始图像会被均匀地切成多个固定大小的图像块(patches),比如 $16 \times 16$ 像素一块。图中最底下显示了将整幅图像切成了 9 个 patch。

2.每个 patch 展平 + 线性投影

每个 patch 被展平成一个向量(Flatten),比如形状 [3, 16, 16](RGB通道)展平成 3×16×16 = 768 维。然后通过一个线性变换(Linear Projection)投影成固定维度,比如 D = 768,作为 token 向量。

3.加上位置编码Position Embedding

因为 Transformer 本身不理解顺序,所以每个 patch 都要加上对应的 位置编码(可学习的向量),告诉模型“这是第几个 patch”。
同时,图中还有一个额外的 * patch —— 可学习的 [class] token,用来代表整张图像的全局信息(类似 BERT 的 [CLS] token),最终用于分类。

4.输入 Transformer 编码器

所有 patch 向量 + [class] token 一起作为 token 序列,送入 Transformer Encoder。
这个 Encoder 会做自注意力计算,进行信息融合。

5.分类头(MLP Head)

只取最前面的 [class] token 作为全图表达,送入一个简单的 MLP(多层感知机) 来输出分类结果(鸟、球、车等)。

2.Transformer Encoder 的结构细节

每个 Encoder Block 都是标准的 Transformer 结构,重复堆叠 $L$ 层。

1.子结构

  1. Multi-Head Self-Attention:让每个 patch token 能看到其它所有 patch 的信息。通过注意力机制动态决定“我关注谁”。
  2. Feed Forward(MLP):每个 token 单独通过一个两层的前馈网络,增加非线性表达能力。

2.其它

LayerNorm 层归一化(Norm)和 残差连接

3.数据预处理

具体实现在models/modeling.py 文件的 Embeddings 类,Embeddings 类负责把图像切成小块(patch),每个小块变成一个向量,加上位置编码 + 特殊的 [CLS] token,最后输出一个 shape 为 [B, N+1, D] 的 token 序列。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class Embeddings(nn.Module):
def __init__(self, config, img_size, in_channels=3):
super(Embeddings, self).__init__()
self.hybrid = None
img_size = _pair(img_size)

"""
1️⃣.判断是否使用 hybrid 模型(是否用 CNN 提特征)
如果用了 grid,说明模型启用了 Hybrid ViT 模式,先用一个 CNN(比如 ResNet)
处理图片,再将 CNN 的输出特征图切成 patch。否则就是标准的 ViT(直接对原始图像
切 patch)。
"""
if config.patches.get("grid") is not None:
grid_size = config.patches["grid"]
patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) # 计算 patch 大小
n_patches = (img_size[0] // 16) * (img_size[1] // 16) # patch 总数量(H/16 × W/16)
self.hybrid = True
else:
"""
2️⃣.计算 patch 的尺寸和数量
将输入图像大小除以 patch 大小,得到总共切了多少个 patch
patch 大小和数量关系着位置编码 shape 和后续 patch embedding 的 shape
"""
patch_size = _pair(config.patches["size"]) # 直接使用指定的 patch 大小
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) # patch 总数


self.hybrid = False


# 如果启用 hybrid 模式,就用 ResNetV2 先提特征
if self.hybrid:
self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
width_factor=config.resnet.width_factor)
# 输出特征通道数:ResNet 的输出通道 * 16(因为是最后 stage 输出)
in_channels = self.hybrid_model.width * 16

"""
3️⃣.构建 patch embedding 层(Conv2d 模拟 patch + linear)
直接用 Conv2d 进行切 patch + flatten + linear映射三合一操作
输入是原始图像或 CNN 特征图,输出是 shape 为 [B, D, H_patch, W_patch] 的特征
"""
self.patch_embeddings = Conv2d(
in_channels=in_channels,
out_channels=config.hidden_size, # 投影后的向量维度
kernel_size=patch_size,
stride=patch_size
)


"""
4️⃣.构建位置编码与 [CLS] token
位置编码:可学习的位置向量,告诉模型每个 patch 的“位置信息”
CLS token:特殊 token,代表整张图的全局语义,最终用于分类
"""
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, config.hidden_size))
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))


# dropout 进行 regularization
self.dropout = Dropout(config.transformer["dropout_rate"])

def forward(self, x):
B = x.shape[0] # 批量大小
# 扩展 cls token,复制为 B 个样本,每个 shape 为 [1, 1, D] → [B, 1, D]
cls_tokens = self.cls_token.expand(B, -1, -1)

# 如果是 hybrid 模式,则先用 ResNet 提取特征
if self.hybrid:
x = self.hybrid_model(x)


"""
5️⃣.前向传播:将图像编码为序列 token
图像 → patch → patch embedding → 加 CLS → 加位置信息 → Dropout
"""
# 用 Conv2d 切 patch 并做线性映射,输出 shape 为 [B, D, H_patch, W_patch]
x = self.patch_embeddings(x)
# 将空间维度展平:[B, D, H_patch, W_patch] → [B, D, N_patches]
x = x.flatten(2)
# 调整维度顺序:[B, D, N] → [B, N, D],符合 Transformer 输入格式
x = x.transpose(-1, -2)
# 拼接 [CLS] token 在序列最前面,得到 [B, N+1, D]
x = torch.cat((cls_tokens, x), dim=1)
# 加上位置编码(Broadcast 自动匹配)
embeddings = x + self.position_embeddings
# 加入 dropout 以防止过拟合
embeddings = self.dropout(embeddings)

return embeddings # 最终输出为 [B, N+1, D],作为 Transformer 的输入

供后续 Transformer encoder 处理的输出格式:

1
2
3
4
5
shape: [B, N+1, D]  
B:batch_size  
N:patch数量  
+1:CLS token  
D:每个 token 的维度(hidden_size)

Embeddings实现逻辑有一个很好的类比:

“把图像裁成一堆小图片(patch),每块都贴上编号(位置编码),再加个代表全图的‘班长token’,然后把它们一起打包发到 transformer 班级里开会。”

4.构建Transformer输入

具体实现在models/modeling.py 文件的 Embeddings 类的 forward 方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def forward(self, x):
B = x.shape[0]

"""
1️⃣.拼接所有patch embedding和position embedding
这三步将所有patch的embedding拼接成一个序列,形状为 [batch, N_patches, hidden]
"""
x = self.patch_embeddings(x) # [B, hidden, H_patch, W_patch]
x = x.flatten(2) # [B, hidden, N_patches]
x = x.transpose(-1, -2) # [B, N_patches, hidden]


"""
2️⃣.添加CLS token:创建了一个可学习的CLS token,并在序列最前面拼接
"""
cls_tokens = self.cls_token.expand(B, -1, -1) # [B, 1, hidden]
x = torch.cat((cls_tokens, x), dim=1) # [B, N_patches+1, hidden]


"""
3️⃣.加上可学习的位置编码
位置编码和patch embedding逐元素相加,保留了位置信息
"""
embeddings = x + self.position_embeddings # [B, N_patches+1, hidden]

embeddings = self.dropout(embeddings)
return embeddings

patch embeddingposition embedding的拼接、CLS token的添加都在Embeddings类的forward方法中实现,最终输出作为Transformer的输入。

5.Transformer编码器

这是Transformer编码器的实现,负责将输入的patch序列(含CLS token)通过多层Transformer Block进行特征提取和全局信息交互。

1.Transformer Encoder

Encoder 类就是多个 Transformer 编码器 Block 的堆叠,并在最后加一个 LayerNorm,最后输出编码后的 token 和(可选)注意力权重。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Encoder(nn.Module):  
def __init__(self, config, vis):
super(Encoder, self).__init__()

# 1️⃣.初始化基础组件
self.vis = vis # 是否启用注意力可视化
self.layer = nn.ModuleList() # 一个装有多层 Transformer Block 的模块列表
# 对最终输出做 LayerNorm
self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)

# 2️⃣.构建多个 Transformer Block(循环添加)
for _ in range(config.transformer["num_layers"]):
# 每层是一个完整的 Transformer Block(带 attention 和 MLP)
layer = Block(config, vis)
# 使用 deepcopy 确保每层独立,参数不共享
self.layer.append(copy.deepcopy(layer))

# hidden_states 是输入序列,一般 shape 是 [B, N, D]
def forward(self, hidden_states):

# 3️⃣.前向传播:依次通过每个 Transformer Block
attn_weights = [] # 存储所有层的注意力矩阵(仅在 vis=True 时启用)
for layer_block in self.layer:
# 输入当前 hidden_states,输出更新后的状态和注意力权重
hidden_states, weights = layer_block(hidden_states)
if self.vis: # 如果启用了 vis,则会收集每一层的注意力权重(用于可视化)
attn_weights.append(weights)

# 4️⃣.对最终的输出 hidden_states 做 LayerNorm,作为输出特征
encoded = self.encoder_norm(hidden_states) # [B, N+1, D] → 标准化处理
return encoded, attn_weights

可以把这个 Encoder 类比成一个“专家委员会”,由多个“专家(Block)”组成,图像每个 patch 是一个“提案”,他们在每一轮里互相交换信息、总结观点,最终得出一份“共识文档(encoded)”。

2.Block结构

这个 Block 类是 Vision Transformer 的一个标准 Transformer 编码器子层,包括:

  • 一个 多头自注意力(Multi-head Self-Attention)模块
  • 一个 前馈神经网络(MLP)
  • 两个 LayerNorm
  • 两个 残差连接(Residual Add)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class Block(nn.Module):  
def __init__(self, config, vis):
super(Block, self).__init__()
# 模型隐藏维度(通常是 768)
self.hidden_size = config.hidden_size
# Attention 前的 LayerNorm(用于稳定训练)
self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
# FFN 前的 LayerNorm
self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
# MLP 模块(含两层全连接层 + GELU)
self.ffn = Mlp(config)
# 多头注意力模块(包含 query/key/value + heads 合并等)
self.attn = Attention(config, vis)

def forward(self, x):
print(x.shape) # 输入 shape:[B, N, D]

# 1️⃣.自注意力
h = x # 保存输入,用于残差连接
x = self.attention_norm(x) # LayerNorm 正则化
x, weights = self.attn(x) # 多头自注意力计算,返回输出和注意力矩阵
x = x + h # 残差连接:加回原始输入

# 2️⃣.前馈网络
h = x # 保存当前输入
x = self.ffn_norm(x) # LayerNorm
x = self.ffn(x) # 前馈神经网络:FC1 → GELU → Dropout → FC2 → Dropout
x = x + h # 再次残差连接
return x, weights # 返回更新后的特征和注意力权重(用于可视化)

# 权重加载部分: 从预训练模型中手动加载权重
def load_from(self, weights, n_block):
ROOT = f"Transformer/encoderblock_{n_block}"
with torch.no_grad(): # 在权重加载过程中禁用梯度计算,加快速度并避免误更新
# 1️⃣.Attention 权重加载
# 从 .npz 权重中取出 Q/K/V/Out 的权重并转置适配 PyTorch 格式
query_weight = np2th(weights[ROOT + "/" + ATTENTION_Q + "/kernel"]).view(self.hidden_size, self.hidden_size).t()
key_weight = ...
value_weight = ...
out_weight = ...
# 加载偏置
query_bias = ...
key_bias = ...
value_bias = ...
out_bias = ...
# 将权重拷贝到 attention 模块的对应参数
self.attn.query.weight.copy_(query_weight)
self.attn.key.weight.copy_(key_weight)

# 2️⃣.MLP 权重加载
# FC1 权重
mlp_weight_0 = np2th(weights[ROOT + "/" + FC_0 + "/kernel"]).t() 
# FC2 权重
mlp_weight_1 = np2th(weights[ROOT + "/" + FC_1 + "/kernel"]).t() 
mlp_bias_0 = ...
mlp_bias_1 = ...

self.ffn.fc1.weight.copy_(mlp_weight_0)
self.ffn.fc2.weight.copy_(mlp_weight_1)

# 3️⃣.LayerNorm 权重加载
# 加载 attention_norm 和 ffn_norm 的 gamma 和 beta(scale 和 bias)
self.attention_norm.weight.copy_(np2th(weights[ROOT + "/" + ATTENTION_NORM + "/scale"]))
self.attention_norm.bias.copy_(np2th(weights[ROOT + "/" + ATTENTION_NORM + "/bias"]))
self.ffn_norm.weight.copy_(...)

Block结构图:

1
2
3
4
5
6
7
8
9
10
11
12
13
  输入 x (B, N, D)

LayerNorm(层归一化)

Residual Add(残差连接)

LayerNorm(层归一化)

FFN(MLP)(前馈神经网络)

Residual Add(残差连接)

输出 x

6.VisionTransformer

这是整个ViT模型的顶层封装,负责将输入图片经过embedding、Transformer编码器、分类头等完整流程,输出最终的分类结果。主要包含

  • 完整的 Transformer 编码器结构
  • 线性分类头 head
  • 从预训练权重中加载参数的load_from() 方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
class VisionTransformer(nn.Module):  
def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
super(VisionTransformer, self).__init__()

self.num_classes = num_classes # 分类类别数(如 1000 或 21843)
self.zero_head = zero_head # 是否在初始化时将分类头置为 0
self.classifier = config.classifier # 分类策略(如 "token")

        # 创建 Transformer 主体(包含 Embedding、Encoder)
self.transformer = Transformer(config, img_size, vis)

# 分类头:将 CLS token 映射为类别预测值
self.head = Linear(config.hidden_size, num_classes)


def forward(self, x, labels=None):
# 输入图片 x,通过 Transformer 得到 token 特征(含 CLS token)
x, attn_weights = self.transformer(x)

# 取出 CLS token 向量(即 x 的第 0 个 token),做分类预测
logits = self.head(x[:, 0])
print(logits.shape)

if labels is not None:
# 如果有标签,则计算 loss(训练模式)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
return loss
else:
# 推理模式:返回 logits 和注意力权重
return logits, attn_weights

# 权重加载方法
def load_from(self, weights):
with torch.no_grad(): # 禁止梯度计算,直接加载权重
# 1.加载分类头参数
if self.zero_head:
# 若启用 zero_head,则将分类头初始化为全 0(不使用预训练分类头)
nn.init.zeros_(self.head.weight)
nn.init.zeros_(self.head.bias)
else:
# 否则从预训练权重中加载分类头
self.head.weight.copy_(np2th(weights["head/kernel"]).t())
self.head.bias.copy_(np2th(weights["head/bias"]).t())

# 2.加载 patch embedding 模块参数(卷积 + 位置编码 + CLS token)
# 加载 patch embedding 卷积核
self.transformer.embeddings.patch_embeddings.weight.
copy_(np2th(weights["embedding/kernel"], conv=True))
self.transformer.embeddings.patch_embeddings.bias.
copy_(np2th(weights["embedding/bias"]))
# 加载 [CLS] token 向量
self.transformer.embeddings.cls_token.
copy_(np2th(weights["cls"]))
# 加载最后 encoder norm 层的参数
self.transformer.encoder.encoder_norm.weight.
copy_(np2th(weights["Transformer/encoder_norm/scale"]))
self.transformer.encoder.encoder_norm.bias.
copy_(np2th(weights["Transformer/encoder_norm/bias"]))

# 3.加载位置编码(含 resize 逻辑)
# 加载位置编码
posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
posemb_new = self.transformer.embeddings.position_embeddings
if posemb.size() == posemb_new.size():
# 尺寸一致,直接加载
self.transformer.embeddings.position_embeddings.copy_(posemb)
else:
# 若尺寸不一致,则插值 resize(一般因输入图片大小不同)
logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
ntok_new = posemb_new.size(1)

if self.classifier == "token":
# token 分类方式:分离 [CLS] 和 patch 部分
posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
ntok_new -= 1
else:
# GAP 分类方式:不包含 token
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

# 原位置编码网格大小
gs_old = int(np.sqrt(len(posemb_grid)))
gs_new = int(np.sqrt(ntok_new))
print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
# 将位置编码 reshape 成二维图像再缩放
posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)

# 合并 [CLS] token 和缩放后的位置编码
posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) self.transformer.embeddings.position_embeddings.
copy_(np2th(posemb))

# 4.加载 Transformer 中每一层 Block 的参数
for bname, block in self.transformer.encoder.named_children():
for uname, unit in block.named_children():
unit.load_from(weights, n_block=uname)

# 5.加载 hybrid CNN 模块的参数(若启用)
if self.transformer.embeddings.hybrid:
# ...

结构图:

1
2
3
4
5
6
VisionTransformer
├── Transformer
│ ├── Embeddings(Patch + Pos + CLS)
│ ├── Encoder(多个 Block + LayerNorm)
├── head(分类器)
└── load_from(支持不同输入尺寸权重加载 + hybrid 模块加载)

7.总结

本文主要介绍了Vision Transformer (ViT) 的主要实现,包括:

  1. ViT概述​:Transformer从NLP扩展到CV领域,ViT是首个在图像分类中表现优异的Transformer模型。核心思路:将图像切分为patch序列,通过Transformer处理,利用self-attention的全局建模能力替代CNN的局部感受野
  2. ViT架构​:
    • 图像分块:将输入图像切分为固定大小的patch
    • Patch Embedding:每个patch展平后通过线性投影得到向量表示
    • 位置编码:添加可学习的位置编码保留空间信息
    • [CLS] Token:添加特殊token用于分类任务
    • Transformer Encoder:多层自注意力机制和前馈网络堆叠
  3. 关键实现​:
    • Embedding层:处理图像分块、位置编码和CLS token
    • Transformer Encoder:包含多头自注意力和MLP,使用LayerNorm和残差连接
    • 分类头:基于CLS token输出分类结果

8.备注