U²-Net显著性目标检测

1.U²-Net介绍与应用

图像分割与U-Net系列模型解析基于U-Net++的细胞分割代码实现 中提到了U-Net系列网络模型,而 U²-Net 虽然是一个U-Net的变体版本,原本用于显著性检测任务,但由于其优异的前景提取能力,逐渐被广泛用于抠图、图像编辑、人像分割等任务中。

1.U²-Net 概述

U²-Net 属于“显著性检测”任务中的网络结构,其核心目标是从图像中识别出前景区域,即显著目标(Salient Object Detection, SOD)。从任务定义来看,它本质上和语义分割非常接近,将图像划分为前景和背景,只是语义标签通常只有两类。

显著性检测(Salient Object Detection)指的是从一幅图像中识别出“最引人注意的”区域,通常即为前景区域。模型的输入是一张图像,输出是一张二值或灰度图,其中白色表示前景(显著区域),黑色表示背景。
例如:一只狗在草地上跑动,狗即为显著对象;一群人站在海滩上,人是前景,海滩是背景。

2.应用

如下图,在肖像素描任务生成中,输入是一张彩色人像图像,输出则像是一幅素描画,纹理和细节都能够得到很好的保留,效果相比很多图像生成或风格转换方法更为自然一些,甚至比一些 GAN 模型生成的还要好,特别是人脸细节部分,U²-Net 在轮廓、皱纹、发丝等区域的提取的也非常准确。

除此之外,U²-Net 在抠图(前景分离)任务中也有不错的效果,无论背景复杂与否,它都能较好地将人物主体从图像中分离出来,这在直播、虚拟背景替换、美颜等应用场景中有着非常广泛的应用价值。例如在某些直播中,主播本来需要一张绿幕才能实现实时背景替换,而利用U²-Net这一类模型就可以做到“无绿幕”抠图,显著提升了便捷性和效果稳定性。

U²-Net 的模型结构不仅能够提取精确的前景,还能为后续处理任务提供干净的语义掩码(mask),这为实时视频处理、特效添加、风格迁移等任务打下了基础,要具备这些能力并非仅仅依赖于“更深的U结构”,而是得益于其独特的设计思想。

2. U²-Net网络架构

U²-Net 的整体结构采用了 Encoder-Decoder(编码器-解码器) 框架,但不同于传统 U-Net,它在每个阶段都使用了一个更复杂的模块——RSU(Residual U-block),也就是双层嵌套U型结构(U-in-U)
整个网络由:

  • 6 个编码模块(En_1 到 En_6)
  • 5 个解码模块(De_1 到 De_5)
  • 多个深监督输出(Sup1 ~ Sup6)
  • 一个最终融合输出(S_fuse)

最终组成一个对称且层层嵌套的 U² 结构。如下图所示

1.RSU 模块

RSU 模块是 U²-Net 的核心模块,全称为 Residual U-shaped block,如上图每个 En_x 和 De_x 模块其实并不是单纯的卷积块,而是一个 RSU 模块,其内部包含一个编码-解码结构(小型 U-Net),然后使用多个空洞卷积扩展感受野,同时保持输出尺寸不变,最终将输入与输出进行残差连接,强化特征传递与复用。

这样做的好处是能有效平衡模型深度与参数量,在保证性能的同时提高了效率,从而增强局部特征提取与上下文信息聚合的能力。

1.基础组件:REBNCONV

RSU 模块的基础单元是 REBNCONV,它由以下部分组成:

  1. 卷积层:3×3 卷积,支持空洞卷积(dilation)
  2. 批归一化:BatchNorm2d,用于稳定训练
  3. ReLU激活:ReLU(inplace=True),用于非线性变换
1
2
3
4
5
6
7
8
9
10
11
12
13
class REBNCONV(nn.Module):  
def __init__(self,in_ch=3,out_ch=3,dirate=1):
super(REBNCONV,self).__init__()

self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)

def forward(self,x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

return xout

2.RSU 模块结构

以 RSU7 为例,RSU 模块包含以下部分:

1.编码器路径(Encoder Path)

输入首先通过一个 REBNCONV 层(rebnconvin),然后依次通过多个 REBNCONV 层和最大池化层(MaxPool2d),逐步降低特征图的空间分辨率,同时增加通道数,最后通过一个空洞卷积(dirate=2)进一步提取特征。

2 解码器路径(Decoder Path)

从最深层开始,通过上采样 _upsample_like 和 REBNCONV 层逐步恢复空间分辨率,每一步都会将当前特征与编码器对应层的特征拼接(torch.cat),形成跳跃连接(skip connection)。

3.残差连接

最终输出是解码器路径的最后一层与输入特征(hxin)的残差和(hx1d + hxin),这有助于梯度流动和特征保留。

4.具体实现(以 RSU7 为例)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class RSU7(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU7,self).__init__()
# 输入层
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
# 编码器路径
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
# 解码器路径
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
# 编码器路径
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
# 解码器路径
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
hx6dup = _upsample_like(hx6d,hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
# 残差连接
return hx1d + hxin
5. RSU 模块的变体

U²-Net 中定义了多种 RSU 模块变体(如 RSU7、RSU6、RSU5、RSU4、RSU4F),它们的区别主要在于:

  • 深度:编码器路径的层数(如 RSU7 有7层,RSU6 有6层)
  • 通道数:中间层(mid_ch)和输出层(out_ch)的通道数
  • 空洞卷积:最深层是否使用空洞卷积(dirate=2)
6. 作用与优势
  • 多尺度特征提取:通过编码器-解码器结构和跳跃连接,RSU 模块能够同时捕获局部和全局特征
  • 残差学习:残差连接有助于梯度流动,避免梯度消失问题
  • 灵活性:不同深度的 RSU 模块可以适应不同复杂度的任务

2.深度监督

深度监督(Deep Supervision)指的是在网络训练时,不仅对最终输出计算损失,而是在中间多个层也加入监督信号,引导模型在多个尺度和层级上学习有意义的特征。

U²-Net 的深度监督(Deep Supervision)机制通过以下步骤实现:

1. 多输出设计

U²-Net 在解码器的每个阶段(stage)都输出一个显著性预测图(side output),这些输出分别对应不同尺度的特征图。模型的前向传播返回了7个输出:

1
2
3
# d0 是最终输出(主输出)
# d1 到 d6 是中间层的输出(side outputs)
d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)

2. 损失函数定义

用于计算每个输出的二值交叉熵损失(BCE Loss)

1
2
3
4
5
6
7
8
9
10
11
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):  
loss0 = bce_loss(d0,labels_v)
loss1 = bce_loss(d1,labels_v)
loss2 = bce_loss(d2,labels_v)
loss3 = bce_loss(d3,labels_v)
loss4 = bce_loss(d4,labels_v)
loss5 = bce_loss(d5,labels_v)
loss6 = bce_loss(d6,labels_v)

loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
return loss0, loss

每个输出(d0 到 d6)都与 ground truth(labels_v)计算 BCE 损失,最终的总损失是所有损失的简单相加(权重相等)。

3. 训练流程中的深度监督

在训练循环中,模型的前向传播返回多个输出,然后通过 muti_bce_loss_fusion 计算损失:

1
2
3
4
5
6
7
d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
# loss2 是主输出(d0)的损失,用于监控训练效果
# loss 是所有输出的总损失,用于反向传播和参数更新
loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

loss.backward()
optimizer.step()

4. 作用与优势

  • 梯度流动:中间层的输出直接参与损失计算,有助于梯度更好地传递到网络深层,缓解梯度消失问题
  • 多尺度监督:不同尺度的输出能够捕获不同层次的特征,提升模型的泛化能力
  • 加速收敛:多输出的监督信号能够加速模型收敛

3.多尺度融合

显著性目标可能大小不一、形态复杂。单一尺度难以同时兼顾,所以需要关注边缘、纹理、细节的低层特征和 关注语义、上下文、全局结构的高层特征

U²-Net 的多尺度融合(Multi-scale Fusion)主要通过以下步骤实现:

1. 多尺度特征提取

U²-Net 的编码器路径(Encoder Path)通过多个 RSU 模块(如 RSU7、RSU6、RSU5 等)逐步提取不同尺度的特征。每个 RSU 模块的输出特征图尺寸逐渐减小,但通道数逐渐增加,从而捕获不同层次的信息。

1
2
3
4
5
6
self.stage1 = RSU7(in_ch, 32, 64) # 输入 -> 64通道
self.stage2 = RSU6(64, 32, 128) # 64通道 -> 128通道
self.stage3 = RSU5(128, 64, 256) # 128通道 -> 256通道
self.stage4 = RSU4(256, 128, 512) # 256通道 -> 512通道
self.stage5 = RSU4F(512, 256, 512) # 512通道 -> 512通道
self.stage6 = RSU4F(512, 256, 512) # 512通道 -> 512通道

2. 解码器路径与跳跃连接

解码器路径(Decoder Path)通过上采样_upsample_like和跳跃连接(skip connection)逐步恢复空间分辨率,同时融合不同尺度的特征:

1
2
3
4
5
6
7
8
9
10
# 解码器路径
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) # 融合 hx6up 和 hx5
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) # 融合 hx5dup 和 hx4
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) # 融合 hx4dup 和 hx3
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) # 融合 hx3dup 和 hx2
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) # 融合 hx2dup 和 hx1

torch.cat((hx6up, hx5), 1) 将上采样后的特征与编码器对应层的特征拼接,形成跳跃连接
每个解码器阶段(stage5d、stage4d 等)进一步处理融合后的特征。

3. 多输出与最终融合

U²-Net 在解码器的每个阶段都输出一个显著性预测图(side output),这些输出通过上采样调整到相同分辨率,然后拼接并融合为最终输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 多输出
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, d1)

# 最终融合
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

self.side1 到 self.side6 是 1×1 卷积层,用于将每个阶段的特征转换为显著性预测图。
self.outconv 是一个 1×1 卷积层,用于将拼接后的多尺度预测图融合为最终输。

4. 作用与优势

  • 多尺度特征捕获:通过不同深度的 RSU 模块,模型能够同时捕获局部细节和全局上下文信息
  • 特征复用:跳跃连接允许低层特征(如边缘、纹理)与高层特征(如语义信息)直接融合,提升分割精度
  • 灵活性:多输出设计不仅用于深度监督,还能在推理时提供多尺度的预测结果

3.编码器

Encoder 编码器(En_1 到 En_6):逐步提取特征,降低分辨率,增加语义信息。

1. 编码器结构

U²-Net 的编码器由多个 RSU(Residual U-block)模块组成,每个模块负责提取不同尺度的特征。编码器的设计遵循以下原则:

  • 逐步降采样:通过最大池化(MaxPool2d)逐步降低特征图的空间分辨率
  • 逐步增加通道数:每个 RSU 模块的输出通道数逐渐增加,以捕获更丰富的特征

2. 具体实现

1.初始化编码器模块

在 U2NET 类的 init 方法中,定义了编码器的各个阶段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
self.stage1 = RSU7(in_ch, 32, 64) # 输入 -> 64通道
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

self.stage2 = RSU6(64, 32, 128) # 64通道 -> 128通道
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

self.stage3 = RSU5(128, 64, 256) # 128通道 -> 256通道
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

self.stage4 = RSU4(256, 128, 512) # 256通道 -> 512通道
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

self.stage5 = RSU4F(512, 256, 512) # 512通道 -> 512通道
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

self.stage6 = RSU4F(512, 256, 512) # 512通道 -> 512通道

RSU7、RSU6、RSU5、RSU4、RSU4F 是不同深度的 RSU 模块,MaxPool2d 用于降采样,stride=2 表示每次将特征图尺寸减半。

2.前向传播过程

在 forward 方法中,编码器的前向传播过程如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
hx = x

# stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)

输入 x 依次通过 stage1 到 stage6 的 RSU 模块,每个阶段后通过 MaxPool2d 降采样,逐步降低特征图的空间分辨率。

3.作用与优势

  • 多尺度特征提取:编码器通过不同深度的 RSU 模块捕获不同尺度的特征
  • 逐步降采样:通过池化层逐步降低空间分辨率,增加感受野
  • 通道数增加:逐步增加通道数,捕获更丰富的特征信息

4.解码器

Decoder 解码器(De_1 到 De_5):逐步恢复分辨率,并融合低层细节与高层语义。

1. 解码器结构

解码器的设计遵循以下原则:

  • 逐步上采样:通过上采样_upsample_like逐步恢复特征图的空间分辨率
  • 跳跃连接:将编码器对应层的特征与上采样后的特征拼接(torch.cat),形成跳跃连接
  • 多输出:每个解码器阶段都输出一个显著性预测图(side output),用于深度监督

2. 具体实现

1.初始化解码器模块

在 U2NET 类的 init 方法中,定义了解码器的各个阶段:

1
2
3
4
5
self.stage5d = RSU4F(1024, 256, 512) # 1024通道 -> 512通道
self.stage4d = RSU4(1024, 128, 256) # 1024通道 -> 256通道
self.stage3d = RSU5(512, 64, 128) # 512通道 -> 128通道
self.stage2d = RSU6(256, 32, 64) # 256通道 -> 64通道
self.stage1d = RSU7(128, 16, 64) # 128通道 -> 64通道

RSU4F、RSU4、RSU5、RSU6、RSU7 是不同深度的 RSU 模块,输入通道数是编码器对应层通道数的两倍(因为跳跃连接会拼接特征)。

2.前向传播过程

在 forward 方法中,解码器的前向传播过程如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 从最深层开始上采样
hx6up = _upsample_like(hx6, hx5)

# 解码器路径
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) # 融合 hx6up 和 hx5
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) # 融合 hx5dup 和 hx4
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) # 融合 hx4dup 和 hx3
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) # 融合 hx3dup 和 hx2
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) # 融合 hx2dup 和 hx1

_upsample_like函数用于将特征图上采样到目标尺寸,torch.cat 将上采样后的特征与编码器对应层的特征拼接,形成跳跃连接。

3.多输出与最终融合

解码器的每个阶段都输出一个显著性预测图(side output),这些输出通过上采样调整到相同分辨率,然后拼接并融合为最终输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 多输出
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, d1)

# 最终融合
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

self.side1self.side6是 1×1 卷积层,用于将每个阶段的特征转换为显著性预测图。self.outconv 是一个 1×1 卷积层,用于将拼接后的多尺度预测图融合为最终输出。

4.作用与优势

  • 逐步恢复空间分辨率:通过上采样逐步恢复特征图的空间分辨率
  • 特征复用:跳跃连接允许低层特征(如边缘、纹理)与高层特征(如语义信息)直接融合,提升分割精度
  • 多输出设计:每个解码器阶段都输出一个显著性预测图,用于深度监督

5.总结

U²-Net 是一种基于 U-Net 改进的显著性检测模型,通过独特的 ​双层嵌套U型结构(RSU模块)​​ 和 ​深度监督机制,在图像分割、抠图、人像素描生成等任务中表现出色。以下是核心内容总结:

1.核心特点​

  • RSU模块​:每个模块内部嵌套小型U-Net,结合残差连接,实现多尺度特征提取与高效参数利用
  • 深度监督​:训练时对6个中间层输出和最终融合结果同时计算损失,增强梯度流动与多尺度学习能力
  • 多尺度融合​:通过跳跃连接和上采样整合不同层级的特征,兼顾局部细节与全局语义

​2.网络架构​

  • 编码器(En_1~En_6)​​:逐级降采样,使用不同深度的RSU模块(如RSU7、RSU4F)提取特征
  • 解码器(De_1~De_5)​​:逐级上采样并融合编码器特征,通过跳跃连接保留细节
  • 输出层​:生成6个中间预测图和1个融合结果,用于深度监督与最终预测

​3.关键应用​

  • 人像素描生成​:精准保留面部轮廓、发丝等细节,效果优于部分GAN模型
  • 无绿幕抠图​:复杂背景下分离前景,适用于直播、虚拟背景替换
  • 显著性检测​:识别图像中的突出物体(如人、动物),输出二值掩码

​4.优势​

  • 高精度​:RSU模块嵌套设计增强特征复用,显著提升边缘和细节处理能力
  • 轻量化​:通过残差连接和模块化设计平衡性能与参数量
  • 灵活性​:支持多任务扩展(如医学图像分割、视频处理)

6.备注

论文地址:https://arxiv.org/pdf/2005.09007
开源代码地址:https://github.com/xuebinqin/U-2-Net