U²-Net显著性目标检测
1.U²-Net介绍与应用
在 图像分割与U-Net系列模型解析 和 基于U-Net++的细胞分割代码实现 中提到了U-Net系列网络模型,而 U²-Net 虽然是一个U-Net的变体版本,原本用于显著性检测任务,但由于其优异的前景提取能力,逐渐被广泛用于抠图、图像编辑、人像分割等任务中。
1.U²-Net 概述
U²-Net 属于“显著性检测”任务中的网络结构,其核心目标是从图像中识别出前景区域,即显著目标(Salient Object Detection, SOD)。从任务定义来看,它本质上和语义分割非常接近,将图像划分为前景和背景,只是语义标签通常只有两类。
显著性检测(Salient Object Detection)指的是从一幅图像中识别出“最引人注意的”区域,通常即为前景区域。模型的输入是一张图像,输出是一张二值或灰度图,其中白色表示前景(显著区域),黑色表示背景。
例如:一只狗在草地上跑动,狗即为显著对象;一群人站在海滩上,人是前景,海滩是背景。
2.应用
如下图,在肖像素描任务生成中,输入是一张彩色人像图像,输出则像是一幅素描画,纹理和细节都能够得到很好的保留,效果相比很多图像生成或风格转换方法更为自然一些,甚至比一些 GAN 模型生成的还要好,特别是人脸细节部分,U²-Net 在轮廓、皱纹、发丝等区域的提取的也非常准确。
除此之外,U²-Net 在抠图(前景分离)任务中也有不错的效果,无论背景复杂与否,它都能较好地将人物主体从图像中分离出来,这在直播、虚拟背景替换、美颜等应用场景中有着非常广泛的应用价值。例如在某些直播中,主播本来需要一张绿幕才能实现实时背景替换,而利用U²-Net这一类模型就可以做到“无绿幕”抠图,显著提升了便捷性和效果稳定性。
U²-Net 的模型结构不仅能够提取精确的前景,还能为后续处理任务提供干净的语义掩码(mask),这为实时视频处理、特效添加、风格迁移等任务打下了基础,要具备这些能力并非仅仅依赖于“更深的U结构”,而是得益于其独特的设计思想。
2. U²-Net网络架构
U²-Net 的整体结构采用了 Encoder-Decoder(编码器-解码器) 框架,但不同于传统 U-Net,它在每个阶段都使用了一个更复杂的模块——RSU(Residual U-block),也就是双层嵌套U型结构(U-in-U)。
整个网络由:
- 6 个编码模块(En_1 到 En_6)
- 5 个解码模块(De_1 到 De_5)
- 多个深监督输出(Sup1 ~ Sup6)
- 一个最终融合输出(S_fuse)
最终组成一个对称且层层嵌套的 U² 结构。如下图所示
1.RSU 模块
RSU 模块是 U²-Net 的核心模块,全称为 Residual U-shaped block,如上图每个 En_x 和 De_x 模块其实并不是单纯的卷积块,而是一个 RSU 模块,其内部包含一个编码-解码结构(小型 U-Net),然后使用多个空洞卷积扩展感受野,同时保持输出尺寸不变,最终将输入与输出进行残差连接,强化特征传递与复用。
这样做的好处是能有效平衡模型深度与参数量,在保证性能的同时提高了效率,从而增强局部特征提取与上下文信息聚合的能力。
1.基础组件:REBNCONV
RSU 模块的基础单元是 REBNCONV,它由以下部分组成:
- 卷积层:3×3 卷积,支持空洞卷积(dilation)
- 批归一化:
BatchNorm2d
,用于稳定训练 - ReLU激活:
ReLU(inplace=True)
,用于非线性变换
1 | class REBNCONV(nn.Module): |
2.RSU 模块结构
以 RSU7 为例,RSU 模块包含以下部分:
1.编码器路径(Encoder Path)
输入首先通过一个 REBNCONV 层(rebnconvin),然后依次通过多个 REBNCONV 层和最大池化层(MaxPool2d),逐步降低特征图的空间分辨率,同时增加通道数,最后通过一个空洞卷积(dirate=2)进一步提取特征。
2 解码器路径(Decoder Path)
从最深层开始,通过上采样 _upsample_like
和 REBNCONV 层逐步恢复空间分辨率,每一步都会将当前特征与编码器对应层的特征拼接(torch.cat),形成跳跃连接(skip connection)。
3.残差连接
最终输出是解码器路径的最后一层与输入特征(hxin)的残差和(hx1d + hxin),这有助于梯度流动和特征保留。
4.具体实现(以 RSU7 为例)
1 | class RSU7(nn.Module): |
5. RSU 模块的变体
U²-Net 中定义了多种 RSU 模块变体(如 RSU7、RSU6、RSU5、RSU4、RSU4F),它们的区别主要在于:
- 深度:编码器路径的层数(如 RSU7 有7层,RSU6 有6层)
- 通道数:中间层(mid_ch)和输出层(out_ch)的通道数
- 空洞卷积:最深层是否使用空洞卷积(dirate=2)
6. 作用与优势
- 多尺度特征提取:通过编码器-解码器结构和跳跃连接,RSU 模块能够同时捕获局部和全局特征
- 残差学习:残差连接有助于梯度流动,避免梯度消失问题
- 灵活性:不同深度的 RSU 模块可以适应不同复杂度的任务
2.深度监督
深度监督(Deep Supervision)指的是在网络训练时,不仅对最终输出计算损失,而是在中间多个层也加入监督信号,引导模型在多个尺度和层级上学习有意义的特征。
U²-Net 的深度监督(Deep Supervision)机制通过以下步骤实现:
1. 多输出设计
U²-Net 在解码器的每个阶段(stage)都输出一个显著性预测图(side output),这些输出分别对应不同尺度的特征图。模型的前向传播返回了7个输出:
1 | # d0 是最终输出(主输出) |
2. 损失函数定义
用于计算每个输出的二值交叉熵损失(BCE Loss)
1 | def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v): |
每个输出(d0 到 d6)都与 ground truth(labels_v)计算 BCE 损失,最终的总损失是所有损失的简单相加(权重相等)。
3. 训练流程中的深度监督
在训练循环中,模型的前向传播返回多个输出,然后通过 muti_bce_loss_fusion
计算损失:
1 | d0, d1, d2, d3, d4, d5, d6 = net(inputs_v) |
4. 作用与优势
- 梯度流动:中间层的输出直接参与损失计算,有助于梯度更好地传递到网络深层,缓解梯度消失问题
- 多尺度监督:不同尺度的输出能够捕获不同层次的特征,提升模型的泛化能力
- 加速收敛:多输出的监督信号能够加速模型收敛
3.多尺度融合
显著性目标可能大小不一、形态复杂。单一尺度难以同时兼顾,所以需要关注边缘、纹理、细节的低层特征和 关注语义、上下文、全局结构的高层特征。
U²-Net 的多尺度融合(Multi-scale Fusion)主要通过以下步骤实现:
1. 多尺度特征提取
U²-Net 的编码器路径(Encoder Path)通过多个 RSU 模块(如 RSU7、RSU6、RSU5 等)逐步提取不同尺度的特征。每个 RSU 模块的输出特征图尺寸逐渐减小,但通道数逐渐增加,从而捕获不同层次的信息。
1 | self.stage1 = RSU7(in_ch, 32, 64) # 输入 -> 64通道 |
2. 解码器路径与跳跃连接
解码器路径(Decoder Path)通过上采样_upsample_like
和跳跃连接(skip connection)逐步恢复空间分辨率,同时融合不同尺度的特征:
1 | # 解码器路径 |
torch.cat((hx6up, hx5), 1)
将上采样后的特征与编码器对应层的特征拼接,形成跳跃连接
每个解码器阶段(stage5d、stage4d 等)进一步处理融合后的特征。
3. 多输出与最终融合
U²-Net 在解码器的每个阶段都输出一个显著性预测图(side output),这些输出通过上采样调整到相同分辨率,然后拼接并融合为最终输出:
1 | # 多输出 |
self.side1 到 self.side6 是 1×1 卷积层,用于将每个阶段的特征转换为显著性预测图。
self.outconv 是一个 1×1 卷积层,用于将拼接后的多尺度预测图融合为最终输。
4. 作用与优势
- 多尺度特征捕获:通过不同深度的 RSU 模块,模型能够同时捕获局部细节和全局上下文信息
- 特征复用:跳跃连接允许低层特征(如边缘、纹理)与高层特征(如语义信息)直接融合,提升分割精度
- 灵活性:多输出设计不仅用于深度监督,还能在推理时提供多尺度的预测结果
3.编码器
Encoder 编码器(En_1 到 En_6):逐步提取特征,降低分辨率,增加语义信息。
1. 编码器结构
U²-Net 的编码器由多个 RSU(Residual U-block)模块组成,每个模块负责提取不同尺度的特征。编码器的设计遵循以下原则:
- 逐步降采样:通过最大池化(MaxPool2d)逐步降低特征图的空间分辨率
- 逐步增加通道数:每个 RSU 模块的输出通道数逐渐增加,以捕获更丰富的特征
2. 具体实现
1.初始化编码器模块
在 U2NET 类的 init 方法中,定义了编码器的各个阶段:
1 | self.stage1 = RSU7(in_ch, 32, 64) # 输入 -> 64通道 |
RSU7、RSU6、RSU5、RSU4、RSU4F 是不同深度的 RSU 模块,MaxPool2d 用于降采样,stride=2 表示每次将特征图尺寸减半。
2.前向传播过程
在 forward 方法中,编码器的前向传播过程如下:
1 | hx = x |
输入 x 依次通过 stage1 到 stage6 的 RSU 模块,每个阶段后通过 MaxPool2d 降采样,逐步降低特征图的空间分辨率。
3.作用与优势
- 多尺度特征提取:编码器通过不同深度的 RSU 模块捕获不同尺度的特征
- 逐步降采样:通过池化层逐步降低空间分辨率,增加感受野
- 通道数增加:逐步增加通道数,捕获更丰富的特征信息
4.解码器
Decoder 解码器(De_1 到 De_5):逐步恢复分辨率,并融合低层细节与高层语义。
1. 解码器结构
解码器的设计遵循以下原则:
- 逐步上采样:通过上采样
_upsample_like
逐步恢复特征图的空间分辨率 - 跳跃连接:将编码器对应层的特征与上采样后的特征拼接(torch.cat),形成跳跃连接
- 多输出:每个解码器阶段都输出一个显著性预测图(side output),用于深度监督
2. 具体实现
1.初始化解码器模块
在 U2NET 类的 init 方法中,定义了解码器的各个阶段:
1 | self.stage5d = RSU4F(1024, 256, 512) # 1024通道 -> 512通道 |
RSU4F、RSU4、RSU5、RSU6、RSU7 是不同深度的 RSU 模块,输入通道数是编码器对应层通道数的两倍(因为跳跃连接会拼接特征)。
2.前向传播过程
在 forward 方法中,解码器的前向传播过程如下:
1 | # 从最深层开始上采样 |
_upsample_like
函数用于将特征图上采样到目标尺寸,torch.cat 将上采样后的特征与编码器对应层的特征拼接,形成跳跃连接。
3.多输出与最终融合
解码器的每个阶段都输出一个显著性预测图(side output),这些输出通过上采样调整到相同分辨率,然后拼接并融合为最终输出:
1 | # 多输出 |
self.side1
到self.side6
是 1×1 卷积层,用于将每个阶段的特征转换为显著性预测图。self.outconv
是一个 1×1 卷积层,用于将拼接后的多尺度预测图融合为最终输出。
4.作用与优势
- 逐步恢复空间分辨率:通过上采样逐步恢复特征图的空间分辨率
- 特征复用:跳跃连接允许低层特征(如边缘、纹理)与高层特征(如语义信息)直接融合,提升分割精度
- 多输出设计:每个解码器阶段都输出一个显著性预测图,用于深度监督
5.总结
U²-Net 是一种基于 U-Net 改进的显著性检测模型,通过独特的 双层嵌套U型结构(RSU模块) 和 深度监督机制,在图像分割、抠图、人像素描生成等任务中表现出色。以下是核心内容总结:
1.核心特点
- RSU模块:每个模块内部嵌套小型U-Net,结合残差连接,实现多尺度特征提取与高效参数利用
- 深度监督:训练时对6个中间层输出和最终融合结果同时计算损失,增强梯度流动与多尺度学习能力
- 多尺度融合:通过跳跃连接和上采样整合不同层级的特征,兼顾局部细节与全局语义
2.网络架构
- 编码器(En_1~En_6):逐级降采样,使用不同深度的RSU模块(如RSU7、RSU4F)提取特征
- 解码器(De_1~De_5):逐级上采样并融合编码器特征,通过跳跃连接保留细节
- 输出层:生成6个中间预测图和1个融合结果,用于深度监督与最终预测
3.关键应用
- 人像素描生成:精准保留面部轮廓、发丝等细节,效果优于部分GAN模型
- 无绿幕抠图:复杂背景下分离前景,适用于直播、虚拟背景替换
- 显著性检测:识别图像中的突出物体(如人、动物),输出二值掩码
4.优势
- 高精度:RSU模块嵌套设计增强特征复用,显著提升边缘和细节处理能力
- 轻量化:通过残差连接和模块化设计平衡性能与参数量
- 灵活性:支持多任务扩展(如医学图像分割、视频处理)
6.备注
论文地址:https://arxiv.org/pdf/2005.09007
开源代码地址:https://github.com/xuebinqin/U-2-Net