基于U-Net++的细胞分割代码实现
下面我们以一个医学图像实例分割任务为例,来介绍在 PyTorch 框架下如何使用 U-Net++ 网络。U-Net++ 是在经典 U-Net 基础上进行改进的语义分割网络,它通过引入密集跳跃连接和深层监督机制,增强了特征融合能力与梯度传播效果,特别适用于医学图像中边界模糊、结构复杂的分割任务。
1.数据预处理
1.数据集介绍
这个数据集是一些细胞图像,我们的目标是做前景背景分离,对每一个细胞做实例分割。数据集有以下特点:
- 每个样本有图像(image)和标签(mask),如下图所示
- 标签(mask),是二值图(0/1),表示像素点是否属于细胞
- 数据集总样本量为 670,样本数量比较少
- 图像(image)尺寸相对也是偏小的
2.数据整合
如上图所示,原始标签是每个细胞都被划分成一个独立 mask,不适合拿来直接做训练。需要将所有单个 mask 合并成一个统一的 mask 图(前景=1,背景=0)。
1 | def main(): |
1.合并掩码
重点讲一下如何合并掩码,如何将某一张图像对应的多个单独掩码文件(每个掩码标注一个目标区域)合并为一个总掩码?
1 | # 初始化一个与原始图像 img 大小相同的二维数组 mask,初始值全为 0 |
合并后如下,图像(image)和标签(mask):
3.数据增强
数据增强我们使用Albumentations,它特别适合用于图像分类、分割、检测等深度学习任务。具有高性能(C++后端加速),高度可配置(支持组合增强),易于与 PyTorch、TensorFlow 等深度学习框架集成等特点。并且支持多种增强操作,包括但不限于:
- 几何变换:旋转、裁剪、缩放、仿射变换、透视变换等
- 颜色变换:亮度、对比度、色调、饱和度、CLAHE、自适应直方图均衡
- 噪声添加:高斯噪声、椒盐噪声、模糊、压缩失真
- 分割/检测支持:自动处理 mask、bboxes 等
1 | # 对训练数据进行增强 |
4.加载数据集
config具体配置参数见3-1
1 | # test_size=0.2 20%的图像将作为验证集,剩下 80% 作为训练集 |
2.U-Net++网络结构
1.VGGBlock
VGG-style 卷积模块,也是图像分割网络(如 U-Net++)中常用的基础构建块,包含了两个连续的卷积层 + BN + ReLU激活,用于提取图像的局部特征。
1 | class VGGBlock(nn.Module): |
2.UNet
一个基于 U-Net++ 架构的语义分割模型(简化版),用于将输入图像逐像素分类成 num_classes 个类别。
1 | class UNet(nn.Module): |
3.NestedUNet
U-Net++(Nested U-Net) 网络,通过更密集的 skip connections(跳跃连接),实现更好的特征融合。
1.命名规则
convX_Y 表示第 X 层、第 Y 次更新,如:
- conv1_0: 第 1 层的初始特征提取
- conv1_1: 第 1 层的第 1 次上采样拼接更新
- conv0_3: 第 0 层的第 3 次融合更新
这些是 U-Net++ 的嵌套跳跃连接结构,越来越密集。
1 | class NestedUNet(nn.Module): |
2.深度监督
深度监督是一种训练策略:在模型的中间层添加辅助输出,并在这些中间输出上也施加监督信号(即计算损失函数),帮助模型更早地获得梯度反馈,从而 加速收敛、提升性能。
其核心思想是,不仅在最终输出处监督模型,还要在中间层也加入监督,防止深层网络训练时梯度消失或收敛缓慢。
3.图示
下图是 U-Net++(Nested U-Net) 网络结构图,展示了它相较于原始 U-Net 所引入的“密集跳跃连接”和“深监督输出”机制。
1.图中每个节点的含义
图中的每个节点用 $X^{i,j}$ 表示,含义如下:
- $i$:代表网络的深度层(下采样层数)
- $j$:代表当前路径上该节点在嵌套结构中的层数(横向扩展层数)
例如: - $X^{0,0}$:最浅层(原图分辨率)第一次卷积输出
- $X^{3,0}$:经过三次下采样后的卷积输出
- $X^{0,3}$:第0层经过3次横向跳跃融合后的特征
2.图中箭头表示含义
右下角图例解释了箭头含义:
- Down-sampling(下采样):粗实线向下箭头
- 如:从 $X^{0,0} → X^{1,0}$,通过 MaxPool 实现
- Up-sampling(上采样):粗实线向上箭头
- 如:从 $X^{1,0} → X^{0,1}$,通过插值上采样(如 bilinear)
- Skip connection(跳跃连接):虚线箭头
- 表示来自不同路径的特征拼接(cat)
3.结构核心:嵌套密集跳跃连接
与原始 U-Net 不同,U-Net++ 每一个横向的节点($X^{i,j},j>0$)不仅融合来自下一级的上采样特征,还融合同级之前所有横向特征(如 $X^{i,0}, X^{i,1}, …$)。这实现了更细致的特征聚合,提升了模型表现力。
4.深监督输出(deep supervision)
如上图中黄色框标注了 4 个输出点:$X^{0,1}、X^{0,2}、X^{0,3}、X^{0,4}$,每个点后面接一个 1×1 卷积层进行通道压缩得到预测,所有输出参与 loss 计算,提升训练稳定性和梯度传播效率。
5.代码与结构图节点对应关系
1 | 图结构节点: 对应代码: |
所有 convX_Y 都是 VGGBlock(即 2 个 3×3 卷积 + BN + ReLU),节点间的数据流动方式严格对应图中的方向和跳跃连接,上采样统一用的是 nn.Upsample(scale_factor=2, mode='bilinear')
。
3.训练
1.训练参数配置
1 | def parse_args(): |
2.损失函数
1.BCEDiceLoss
BCEDiceLoss结合了两种损失函数:
- BCE(Binary Cross Entropy):像素级分类准确性
- Dice Loss:评估预测区域与真实区域的重叠程度,适合前景-背景不平衡问题
1 | class BCEDiceLoss(nn.Module): |
前景像素非常少(比如肿瘤检测、医学分割、道路提取),Dice 可增强区域预测能力,BCE 保证细节与边界预测精度。
2.LovaszHingeLoss
结构感知的分割损失函数,基于 Lovász Hinge Loss 的结构性优化目标:
- 与 IoU(交并比)相关,专门为 非凸、不可微的 mIoU 目标 设计的近似优化方法
- 强调 结构 的正确性(不是单个像素)
1 | class LovaszHingeLoss(nn.Module): |
优势 | 描述 |
---|---|
结构感知 | 优化 IoU 相关指标而非逐像素损失 |
非凸优化 | 用次梯度方法优化非连续目标(如 mIoU) |
表现优秀 | 在结构敏感型分割(如边缘、孔洞等)表现更佳 |
3.对比
损失函数 | 优点 | 适用场景 |
---|---|---|
BCEDiceLoss | 简单有效,能兼顾像素准确与区域重叠 | 医学图像、前景稀疏场景 |
LovaszHingeLoss | 更贴近 IoU 优化目标,更关注结构完整性 | 对结构敏感的任务,如道路提取、器官边界分割等 |
3.优化器
1 | # 1.筛选需要优化的参数 |
Adam 优化器适用:训练不稳定或梯度稀疏(如 Transformer、U-Net)。
SGD 优化器适用:训练稳定、想控制收敛过程精细(如 ResNet、分类任务)。
4.学习率
1 | if config['scheduler'] == 'CosineAnnealingLR': |
1.CosineAnnealingLR
余弦退火,适合训练后期希望缓慢收敛,防止震荡,比如分类、分割、目标检测等任务
2.ReduceLROnPlateau
性能不提升时降低学习率,适合模型容易早停,需要动态响应性能变化,比如医学图像、少样本任务等
3.MultiStepLR
多阶段下降,训练有明显阶段性(如 ImageNet 训练常见策略)
4.ConstantLR
不使用调度器,适合做实验对照,或已知固定学习率效果不错
5.训练流程
1.训练
1 | def train(config, train_loader, model, criterion, optimizer): |
2.验证
1 | def validate(config, val_loader, model, criterion): |
3.主流程
1 | def main(): |
4.验证
1.可视化结果
1 | def plot_examples(datax, datay, model, num_examples=6): |
2.验证
1 | def main(): |
可视化样例如下图:
5.总结
本文主要介绍了在PyTorch框架下使用U-Net++网络进行医学图像实例分割的代码实现:
- 数据预处理
- 数据集特点:670张细胞图像,二值掩码标签(0背景/1细胞),图像尺寸较小
- 关键操作:合并多个独立细胞掩码为统一mask,使用OpenCV进行图像缩放和通道处理
- 数据增强:采用Albumentations库实现旋转/翻转/色彩扰动等增强策略
- U-Net++网络结构
- 改进点:相比经典U-Net增加密集跳跃连接和深度监督机制
- 核心组件:
- VGGBlock:基础卷积模块(3x3卷积+BN+ReLU×2)
- 嵌套解码结构:convX_Y命名规则表示第X层第Y次特征融合
- 深度监督:中间层输出辅助损失,缓解梯度消失
- 模型训练
- 损失函数:BCEDiceLoss(兼顾像素精度和区域重叠)和LovaszHingeLoss(优化IoU指标)
- 训练策略:余弦退火学习率、早停机制、模型保存最佳检查点
- 评估指标:IoU(交并比)作为主要评估标准
- 结果验证
- 可视化对比:并列显示原图、预测结果和真实标注
- 性能评估:批量计算验证集平均IoU,保存预测结果图像
6.备注
环境
- mac: 15.2
- python: 3.12.4
- numpy: 1.26.4
- opencv-python: 4.11.0.86
- albumentations: 2.0.6
资源和代码
https://github.com/keychankc/dl_code_for_blog/tree/main/011_U-net%2B%2B