基于U-Net++的细胞分割代码实现

下面我们以一个医学图像实例分割任务为例,来介绍在 PyTorch 框架下如何使用 U-Net++ 网络。U-Net++ 是在经典 U-Net 基础上进行改进的语义分割网络,它通过引入密集跳跃连接和深层监督机制,增强了特征融合能力与梯度传播效果,特别适用于医学图像中边界模糊、结构复杂的分割任务。

1.数据预处理

1.数据集介绍

这个数据集是一些细胞图像,我们的目标是做前景背景分离,对每一个细胞做实例分割。数据集有以下特点:

  1. 每个样本有图像(image)和标签(mask),如下图所示
  2. 标签(mask),是二值图(0/1),表示像素点是否属于细胞
  3. 数据集总样本量为 670,样本数量比较少
  4. 图像(image)尺寸相对也是偏小的

2.数据整合

如上图所示,原始标签是每个细胞都被划分成一个独立 mask,不适合拿来直接做训练。需要将所有单个 mask 合并成一个统一的 mask 图(前景=1,背景=0)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def main():  
# 设置输出图像大小为 96×96
img_size = 96
# 读取所有子文件夹,每个子文件夹对应一个图像样本(包含 images 和 masks 子目录)
paths = glob('inputs/stage1_train/*')

# 用于存放缩放后的图像(image)
os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)
# 用于存放合并后的掩码图像(mask)
os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)

for i in tqdm(range(len(paths))): # 用tqdm可在控制台显示处理进度
path = paths[i]
# 读取图片
img = cv2.imread(os.path.join(path, 'images', os.path.basename(path) + '.png'))

# 合并掩码
mask = np.zeros((img.shape[0], img.shape[1]))
for mask_path in glob(os.path.join(path, 'masks', '*')):
mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127
mask[mask_] = 1

# 图像通道处理
if len(img.shape) == 2: # 灰度图(2D),复制为3通道图像
img = np.tile(img[..., None], (1, 1, 3))
if img.shape[2] == 4: # RGBA(4通道),去掉透明度通道,只保留RGB
img = img[..., :3]

# 将图像和掩码都缩放为 96×96
img = cv2.resize(img, (img_size, img_size))
mask = cv2.resize(mask, (img_size, img_size))

# 保存缩放后的图像和掩码
cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size, os.path.basename(path) + '.png'), img)
cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size, os.path.basename(path) + '.png'), (mask * 255).astype('uint8'))

1.合并掩码

重点讲一下如何合并掩码,如何将某一张图像对应的多个单独掩码文件(每个掩码标注一个目标区域)合并为一个总掩码?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 初始化一个与原始图像 img 大小相同的二维数组 mask,初始值全为 0
# mask 用于存放所有分割区域的合并结果(即所有目标的整体掩码)
mask = np.zeros((img.shape[0], img.shape[1]))

# 遍历路径 path/masks/ 目录下的所有掩码文件(每个文件对应一个实例或目标区域)
# glob 返回所有匹配的文件路径
for mask_path in glob(os.path.join(path, 'masks', '*')):

# 读取当前掩码图像 mask_path,以灰度模式读取
# > 127 将图像二值化,转换为布尔数组,表示掩码中前景(目标)的位置
mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127

# 将所有当前掩码图中为 True(即目标区域)的位置,在总的 mask 中赋值为 1
# 最终 mask 中为 1 的地方表示所有目标区域的合集
mask[mask_] = 1

合并后如下,图像(image)和标签(mask):

3.数据增强

数据增强我们使用Albumentations,它特别适合用于图像分类、分割、检测等深度学习任务。具有高性能(C++后端加速),高度可配置(支持组合增强),易于与 PyTorch、TensorFlow 等深度学习框架集成等特点。并且支持多种增强操作,包括但不限于:

  • 几何变换:旋转、裁剪、缩放、仿射变换、透视变换等
  • 颜色变换:亮度、对比度、色调、饱和度、CLAHE、自适应直方图均衡
  • 噪声添加:高斯噪声、椒盐噪声、模糊、压缩失真
  • 分割/检测支持:自动处理 mask、bboxes 等
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 对训练数据进行增强
train_transform = Compose([
A.RandomRotate90(p=0.5), # 以50%的概率将图像随机旋转90、180或270度
A.HorizontalFlip(p=0.5), # 以50%的概率进行水平翻转
A.VerticalFlip(p=0.5), # 以50%的概率进行垂直翻转
OneOf([
A.HueSaturationValue(p=1), # 对图像的色调、饱和度和亮度进行随机扰动
A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1), # 随机调整亮度和对比度
], p=1), # 二选一,只选择其中一个图像增强方式(概率为1,必须执行其一)
A.Resize(config['input_h'], config['input_w']), # 将图像缩放到模型输入所需的高度和宽度
A.Normalize(), # 将图像像素归一化到标准分布
])

# 对验证数据仅做必要的预处理
val_transform = Compose([
A.Resize(config['input_h'], config['input_w']), # 尺寸调整 96 * 96
A.Normalize(), # 归一化
])

4.加载数据集

config具体配置参数见3-1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# test_size=0.2 20%的图像将作为验证集,剩下 80% 作为训练集
# random_state=41 设置随机种子,确保每次运行划分结果一致(可复现)
train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

# 加载训练数据集
train_dataset = Dataset(
img_ids=train_img_ids, # 用于定位每张图像
img_dir=os.path.join('inputs', config['dataset'], 'images'), # 图像所在目录
mask_dir=os.path.join('inputs', config['dataset'], 'masks'), # 掩码文件所在目录
img_ext=config['img_ext'], # 图像扩展名
mask_ext=config['mask_ext'], # 掩码扩展名
num_classes=config['num_classes'], # 类别数量(用于多分类掩码,通常为1表示二分类)
transform=train_transform) # 图像与掩码的增强变换
# 加载验证数据集
val_dataset = Dataset(
img_ids=val_img_ids,
img_dir=os.path.join('inputs', config['dataset'], 'images'),
mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
img_ext=config['img_ext'],
mask_ext=config['mask_ext'],
num_classes=config['num_classes'],
transform=val_transform)

# 训练数据加载器
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config['batch_size'], # 每次训练所取的样本数量
shuffle=True, # 是否打乱数据顺序
num_workers=config['num_workers'], # 读取数据时的线程数
drop_last=True) # 如果最后一个batch不足batch_size,是否丢弃它
# 验证数据加载器
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=config['batch_size'],
shuffle=False,
num_workers=config['num_workers'],
drop_last=False)

2.U-Net++网络结构

1.VGGBlock

VGG-style 卷积模块,也是图像分割网络(如 U-Net++)中常用的基础构建块,包含了两个连续的卷积层 + BN + ReLU激活,用于提取图像的局部特征。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class VGGBlock(nn.Module):  
# in_channels: 输入特征图的通道数
# middle_channels: 第一个卷积的输出通道数
# out_channels: 第二个卷积的输出通道数
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
# ReLU 激活函数,inplace=True 节省内存
self.relu = nn.ReLU(inplace=True)
# Conv2d 3×3 卷积层,padding=1 保证输入输出尺寸一致
self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
# BatchNorm2d 批标准化,加快训练速度、稳定收敛
self.bn1 = nn.BatchNorm2d(middle_channels)
self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)

def forward(self, x): # 前向传播
out = self.conv1(x) # 第一次卷积
out = self.bn1(out) # 批标准化
out = self.relu(out) # 激活

out = self.conv2(out) # 第二次卷积
out = self.bn2(out) # 批标准化
out = self.relu(out) # 激活

return out

2.UNet

一个基于 U-Net++ 架构的语义分割模型(简化版),用于将输入图像逐像素分类成 num_classes 个类别。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class UNet(nn.Module):  
# num_classes: 最终输出类别数(语义分割中每个像素属于哪个类)
# input_channels: 输入图像的通道数(RGB=3,灰度=1)
def __init__(self, num_classes, input_channels=3, **kwargs):
super().__init__()
# 每一层的通道数配置
nb_filter = [32, 64, 128, 256, 512]

# 编码器部分(下采样路径)
# 图像尺寸逐步减小,通道数逐步增大,用于提取深层语义特征
self.pool = nn.MaxPool2d(2, 2) # 下采样操作
self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

# 解码器部分(上采样路径)
# 包含多个上采样 + 特征融合(concat)+ 卷积
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 上采样
self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])

# 最终输出层
self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)


def forward(self, input): # 前向传播
# 编码阶段(下采样):图像尺寸每次减半,特征逐步变深
x0_0 = self.conv0_0(input)
x1_0 = self.conv1_0(self.pool(x0_0))
x2_0 = self.conv2_0(self.pool(x1_0))
x3_0 = self.conv3_0(self.pool(x2_0))
x4_0 = self.conv4_0(self.pool(x3_0))

# 解码阶段(U-Net++风格)
# 每一层的输出不仅使用其下层的上采样结果,还结合了它本身的特征图,形成一个更密集的 skip-connection 网络结构
x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))

# 输出预测图
# output是一个张量,形状为 (B, num_classes, H, W),表示每个像素的类别预测
output = self.final(x0_4)
return output

3.NestedUNet

U-Net++(Nested U-Net) 网络,通过更密集的 skip connections(跳跃连接),实现更好的特征融合。

1.命名规则

convX_Y 表示第 X 层、第 Y 次更新,如:

  • conv1_0: 第 1 层的初始特征提取
  • conv1_1: 第 1 层的第 1 次上采样拼接更新
  • conv0_3: 第 0 层的第 3 次融合更新
    这些是 U-Net++ 的嵌套跳跃连接结构,越来越密集。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class NestedUNet(nn.Module):  
# num_classes: 分割类别数
# input_channels: 输入图像通道数,默认为 3(RGB 图像)
# deep_supervision: 是否启用多层输出作为监督信号(训练时更有效,测试时只用最后一层输出)
def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
super().__init__()

nb_filter = [32, 64, 128, 256, 512]

self.deep_supervision = deep_supervision

self.pool = nn.MaxPool2d(2, 2)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

if self.deep_supervision:
self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
else:
self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)


def forward(self, input):
x0_0 = self.conv0_0(input)
x1_0 = self.conv1_0(self.pool(x0_0))
x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

x2_0 = self.conv2_0(self.pool(x1_0))
x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

x3_0 = self.conv3_0(self.pool(x2_0))
x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

x4_0 = self.conv4_0(self.pool(x3_0))
x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

# 输出层(支持深监督)
if self.deep_supervision:
output1 = self.final1(x0_1)
output2 = self.final2(x0_2)
output3 = self.final3(x0_3)
output4 = self.final4(x0_4)
# 输出多个不同深度的预测,常用于训练阶段融合损失
return [output1, output2, output3, output4]
else:
output = self.final(x0_4) # 卷积,输出最终预测
return output

2.深度监督

深度监督是一种训练策略:在模型的中间层添加辅助输出,并在这些中间输出上也施加监督信号(即计算损失函数),帮助模型更早地获得梯度反馈,从而 加速收敛提升性能
其核心思想是,不仅在最终输出处监督模型,还要在中间层也加入监督,防止深层网络训练时梯度消失或收敛缓慢。

3.图示

下图是 U-Net++(Nested U-Net) 网络结构图,展示了它相较于原始 U-Net 所引入的“密集跳跃连接”和“深监督输出”机制。

1.图中每个节点的含义

图中的每个节点用 $X^{i,j}$ 表示,含义如下:

  • $i$:代表网络的深度层(下采样层数)
  • $j$:代表当前路径上该节点在嵌套结构中的层数(横向扩展层数)
    例如:
  • $X^{0,0}$:最浅层(原图分辨率)第一次卷积输出
  • $X^{3,0}$:经过三次下采样后的卷积输出
  • $X^{0,3}$:第0层经过3次横向跳跃融合后的特征
2.图中箭头表示含义

右下角图例解释了箭头含义:

  • Down-sampling(下采样):粗实线向下箭头
    • 如:从 $X^{0,0} → X^{1,0}$,通过 MaxPool 实现
  • Up-sampling(上采样):粗实线向上箭头
    • 如:从 $X^{1,0} → X^{0,1}$,通过插值上采样(如 bilinear)
  • Skip connection(跳跃连接):虚线箭头
    • 表示来自不同路径的特征拼接(cat)
3.结构核心:嵌套密集跳跃连接

与原始 U-Net 不同,U-Net++ 每一个横向的节点($X^{i,j},j>0$)不仅融合来自下一级的上采样特征,还融合同级之前所有横向特征(如 $X^{i,0}, X^{i,1}, …$)。这实现了更细致的特征聚合,提升了模型表现力。

4.深监督输出(deep supervision)

如上图中黄色框标注了 4 个输出点:$X^{0,1}、X^{0,2}、X^{0,3}、X^{0,4}$,每个点后面接一个 1×1 卷积层进行通道压缩得到预测,所有输出参与 loss 计算,提升训练稳定性和梯度传播效率。

5.代码与结构图节点对应关系
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
图结构节点:             对应代码:
X^0,0 → self.conv0_0
X^1,0 → self.conv1_0
X^2,0 → self.conv2_0
X^3,0 → self.conv3_0
X^4,0 → self.conv4_0

X^0,1 → self.conv0_1
X^1,1 → self.conv1_1
X^2,1 → self.conv2_1
X^3,1 → self.conv3_1

X^0,2 → self.conv0_2
X^1,2 → self.conv1_2
X^2,2 → self.conv2_2

X^0,3 → self.conv0_3
X^1,3 → self.conv1_3

X^0,4 → self.conv0_4

所有 convX_Y 都是 VGGBlock(即 2 个 3×3 卷积 + BN + ReLU),节点间的数据流动方式严格对应图中的方向和跳跃连接,上采样统一用的是 nn.Upsample(scale_factor=2, mode='bilinear')

3.训练

1.训练参数配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def parse_args():  
parser = argparse.ArgumentParser()

# 模型名称
parser.add_argument('--name', default=None, help='model name: (default: arch+timestamp)')
# 训练轮数
parser.add_argument('--epochs', default=10, type=int, metavar='N', help='number of total epochs to run')
# 每个batch的数据量
parser.add_argument('-b', '--batch_size', default=8, type=int, metavar='N', help='mini-batch size (default: 16)')

# model,控制模型结构、是否使用深度监督、输入图像尺寸
parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
choices=ARCH_NAMES,
help='model architecture: ' + ' | '.join(ARCH_NAMES) + ' (default: NestedUNet)')
# 是否使用深度监督(U-Net++ 特性),false,只用最后一层输出
parser.add_argument('--deep_supervision', default=False, type=str2bool)
# 输入图像通道数 3表示彩色图像(RGB)
parser.add_argument('--input_channels', default=3, type=int,
help='input channels')
# 类别数量1 表示二分类(前景 vs 背景)
parser.add_argument('--num_classes', default=1, type=int,
help='number of classes')
# 输入图像宽
parser.add_argument('--input_w', default=96, type=int,
help='image width')
# 输入图像高
parser.add_argument('--input_h', default=96, type=int,
help='image height')

# 损失函数
parser.add_argument('--loss', default='BCEDiceLoss', choices=LOSS_NAMES,
help='loss: ' + ' | '.join(LOSS_NAMES) + ' (default: BCEDiceLoss)')

# 数据集
parser.add_argument('--dataset', default='dsb2018_96', help='dataset name')
# 图像文件后缀
parser.add_argument('--img_ext', default='.png', help='image file extension')
# 掩码的文件后缀
parser.add_argument('--mask_ext', default='.png', help='mask file extension')

# 优化器
parser.add_argument('--optimizer', default='SGD',
choices=['Adam', 'SGD'],
help='loss: ' + ' | '.join(['Adam', 'SGD']) + ' (default: Adam)')
# 初始学习率
parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float, metavar='LR', help='initial learning rate')
# 动量项(SGD 特有)
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
# 权重衰减(正则化)
parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay')
# 是否使用 Nesterov 动量
parser.add_argument('--nesterov', default=False, type=str2bool, help='nesterov')

# 学习率调度器
parser.add_argument('--scheduler', default='CosineAnnealingLR',
choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
# 最小学习率
parser.add_argument('--min_lr', default=1e-5, type=float, help='minimum learning rate')
# 每次降低的倍数,用于某些调度器
parser.add_argument('--factor', default=0.1, type=float)
# scheduler 的耐心(多少次不提升再调整学习率)
parser.add_argument('--patience', default=2, type=int)
# 在哪些 epoch 降低学习率
parser.add_argument('--milestones', default='1,2', type=str)
# 学习率衰减系数
parser.add_argument('--gamma', default=2 / 3, type=float)
# 早停 -1 表示不启用
parser.add_argument('--early_stopping', default=-1, type=int, metavar='N', help='early stopping (default: -1)')
# 数据加载线程数
parser.add_argument('--num_workers', default=0, type=int)

config = parser.parse_args()

return config

2.损失函数

1.BCEDiceLoss

BCEDiceLoss结合了两种损失函数:

  • BCE(Binary Cross Entropy):像素级分类准确性
  • Dice Loss:评估预测区域与真实区域的重叠程度,适合前景-背景不平衡问题
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class BCEDiceLoss(nn.Module):  
def __init__(self):
super().__init__()

def forward(self, input, target):
"""
:param input: 模型输出 logits
:param target: ground truth 掩码(0/1)
:return:
"""
# 1.BCE Loss
# 直接使用带logits的BCE,内部自动做了sigmoid
bce = F.binary_cross_entropy_with_logits(input, target)
smooth = 1e-5

# 2.Dice Loss
input = torch.sigmoid(input) # 激活sigmoid,转换为概率
# reshape 展平
num = target.size(0)
input = input.view(num, -1)
target = target.view(num, -1)
# 计算Dice系数,Dice系数表示预测和目标之间的重叠程度,越大越好
intersection = (input * target)
dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
dice = 1 - dice.sum() / num

return 0.5 * bce + dice

前景像素非常少(比如肿瘤检测、医学分割、道路提取),Dice 可增强区域预测能力,BCE 保证细节与边界预测精度。

2.LovaszHingeLoss

结构感知的分割损失函数,基于 Lovász Hinge Loss 的结构性优化目标:

  • 与 IoU(交并比)相关,专门为 非凸、不可微的 mIoU 目标 设计的近似优化方法
  • 强调 结构 的正确性(不是单个像素)
1
2
3
4
5
6
7
8
9
10
class LovaszHingeLoss(nn.Module):  
def __init__(self):
super().__init__()

def forward(self, input, target):
input = input.squeeze(1)
target = target.squeeze(1)
# 计算每张图的结构性错误
loss = lovasz_hinge(input, target, per_image=True)
return loss
优势 描述
结构感知 优化 IoU 相关指标而非逐像素损失
非凸优化 用次梯度方法优化非连续目标(如 mIoU)
表现优秀 在结构敏感型分割(如边缘、孔洞等)表现更佳

3.对比

损失函数 优点 适用场景
BCEDiceLoss 简单有效,能兼顾像素准确与区域重叠 医学图像、前景稀疏场景
LovaszHingeLoss 更贴近 IoU 优化目标,更关注结构完整性 对结构敏感的任务,如道路提取、器官边界分割等

3.优化器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 1.筛选需要优化的参数
# 只把可训练的参数传给优化器(忽略被冻结的参数)
params = filter(lambda p: p.requires_grad, model.parameters())

# 2.根据配置创建优化器
# 读取config字典,按配置创建对应的优化器实例
if config['optimizer'] == 'Adam':
# 自适应学习率的 Adam 优化器
optimizer = optim.Adam(
params, lr=config['lr'], weight_decay=config['weight_decay'])
elif config['optimizer'] == 'SGD':
# 经典的随机梯度下降优化器
# momentum:动量参数
# nesterov:是否使用 Nesterov 动量加速(提前看一步梯度)
# weight_decay:L2 正则项
optimizer = optim.SGD(params,
lr=config['lr'],
momentum=config['momentum'],
nesterov=config['nesterov'],
weight_decay=config['weight_decay'])
else:
raise NotImplementedError

Adam 优化器适用:训练不稳定或梯度稀疏(如 Transformer、U-Net)。
SGD 优化器适用:训练稳定、想控制收敛过程精细(如 ResNet、分类任务)。

4.学习率

1
2
3
4
5
6
7
8
9
10
11
12
13
if config['scheduler'] == 'CosineAnnealingLR':  
# 余弦下降:学习率在训练过程中像余弦曲线一样先慢慢下降,接近最后阶段时趋于 min_lr
scheduler = lr_scheduler.CosineAnnealingLR(
optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
elif config['scheduler'] == 'ReduceLROnPlateau':
# 当模型某个评价指标(如验证 loss)在若干轮内不再下降时,就减小学习率
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'], verbose=True, min_lr=config['min_lr'])
elif config['scheduler'] == 'MultiStepLR':
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])
elif config['scheduler'] == 'ConstantLR':
scheduler = None
else:
raise NotImplementedError

1.CosineAnnealingLR

余弦退火,适合训练后期希望缓慢收敛,防止震荡,比如分类、分割、目标检测等任务

2.ReduceLROnPlateau

性能不提升时降低学习率,适合模型容易早停,需要动态响应性能变化,比如医学图像、少样本任务等

3.MultiStepLR

多阶段下降,训练有明显阶段性(如 ImageNet 训练常见策略)

4.ConstantLR

不使用调度器,适合做实验对照,或已知固定学习率效果不错

5.训练流程

1.训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def train(config, train_loader, model, criterion, optimizer):  

# 1.初始化状态
# 记录每轮的平均 loss 和 IOU
avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
model.train() # 设置为训练模式
pbar = tqdm(total=len(train_loader)) # 进度条,显示训练过程

# 2.遍历训练数据
for input, target, _ in train_loader: # 每次读取一个批次(image mask)

# 3.处理设备
if torch.cuda.is_available():
input = input.cuda()
target = target.cuda()
else:
device = torch.device("cpu")
input = input.to(device)
target = target.to(device)

# 4. 正向传播 & 计算损失 + IoU
if config['deep_supervision']:
outputs = model(input) # 多输出用于深度监督
loss = 0
for output in outputs:
loss += criterion(output, target)
loss /= len(outputs)
iou = iou_score(outputs[-1], target)
# model(input) 返回多个输出(如 x0_1, x0_2, …, x0_4)
# 每个输出都参与损失计算,然后平均
# 只用最后一个输出计算 IoU(最深层最准确)
else:
output = model(input)
loss = criterion(output, target)
iou = iou_score(output, target)
# 正常的单输出模型,直接计算损失和IoU

# 5. 梯度更新
optimizer.zero_grad() # 清空旧梯度
loss.backward() # 反向传播
optimizer.step() # 执行一次优化器更新

# 6.更新统计 & 展示进度
avg_meters['loss'].update(loss.item(), input.size(0))
avg_meters['iou'].update(iou, input.size(0))
# 每个 batch 结束后更新累计 loss 与 iou 的加权平均
postfix = OrderedDict([('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg),])
pbar.set_postfix(postfix)
pbar.update(1)
pbar.close()

# 7. 返回训练结果
# 平均 loss 和 IoU,用于日志记录、调度器调整、验证等
return OrderedDict([('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg)])

2.验证

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def validate(config, val_loader, model, criterion):  

# 1.初始化两个累加器:分别记录本轮平均损失 与 IoU
avg_meters = {'loss': AverageMeter(),
'iou': AverageMeter()}

# 2.把模型切到评估模式:BatchNorm使用全局均值/方差 Dropout关闭随机丢弃
model.eval()

# 3.不计算梯度,节省显存
with torch.no_grad():
pbar = tqdm(total=len(val_loader))

# 4.每次读取一个批次(image mask)
for input, target, _ in val_loader:
# 5.处理设备
if torch.cuda.is_available():
input = input.cuda()
target = target.cuda()
else:
device = torch.device("cpu")
input = input.to(device)
target = target.to(device)

# 6.前向推理与损失 / IoU 计算
if config['deep_supervision']:
outputs = model(input)
loss = 0
for output in outputs:
# 对每个输出单独计算损失,再求平均
loss += criterion(output, target)
loss /= len(outputs)
# 只用最后一级 outputs[-1] 计算 IoU(最精细)
iou = iou_score(outputs[-1], target)
else:
output = model(input)
loss = criterion(output, target)
iou = iou_score(output, target)

# 7.指标累积
avg_meters['loss'].update(loss.item(), input.size(0))
avg_meters['iou'].update(iou, input.size(0))

postfix = OrderedDict([
('loss', avg_meters['loss'].avg),
('iou', avg_meters['iou'].avg),
])
pbar.set_postfix(postfix)
pbar.update(1)
pbar.close()

# 以OrderedDict形式返回本epoch的 验证损失 与 验证 IoU
return OrderedDict([('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg)])

3.主流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def main():  
# 获取参数字典
config = vars(parse_args())

# 自动生成模型名
...

# 保存配置文件
...

# 初始化损失函数
...

# 创建模型
...

# 选择优化器
...

# 配置学习率调度器
...

# 加载图像ID,并划分训练/验证集
...

# 数据增强策略
...

# 加载数据集
...

# 日志记录结构
log = OrderedDict([
('epoch', []),
('lr', []),
('loss', []),
('iou', []),
('val_loss', []),
('val_iou', []),
])

# 1.初始化参数
best_iou = 0 # 记录迄今为止验证集上的最佳 IoU,用于模型保存判断
trigger = 0 # - 记录连续验证集没有提升的 epoch 数,用于早停

# 2.主训练循环
# 循环执行多个 epoch,每轮执行完整的训练-验证流程
for epoch in range(config['epochs']):
print('Epoch [%d/%d]' % (epoch, config['epochs']))

# 3. 每轮训练 & 验证
train_log = train(config, train_loader, model, criterion, optimizer)
val_log = validate(config, val_loader, model, criterion)

# 4.学习率调度器更新
if config['scheduler'] == 'CosineAnnealingLR': # 基于 epoch 数变化
scheduler.step()
elif config['scheduler'] == 'ReduceLROnPlateau': # 根据 val loss 变化调整
scheduler.step(val_log['loss'])

# 5. 打印指标
print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
% (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))

# 6.日志记录
log['epoch'].append(epoch)
log['lr'].append(config['lr'])
log['loss'].append(train_log['loss'])
log['iou'].append(train_log['iou'])
log['val_loss'].append(val_log['loss'])
log['val_iou'].append(val_log['iou'])
pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'], index=False)

trigger += 1

# 7.保存最佳模型
if val_log['iou'] > best_iou:
torch.save(model.state_dict(), 'models/%s/model.pth' %
config['name'])
best_iou = val_log['iou']
print("=> saved best model")
trigger = 0

# 8.早停判断
if 0 <= config['early_stopping'] <= trigger:
print("=> early stopping")
break

# 9. 清理缓存
torch.cuda.empty_cache()

4.验证

1.可视化结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def plot_examples(datax, datay, model, num_examples=6):  
# 创建画布,准备绘制num_examples行、每行3列的图像(原图、预测、真值)
fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18, 4 * num_examples))
m = datax.shape[0] # 数据集样本总数(datax 是一个 batch 或完整验证集)

for row_num in range(num_examples):
# 随机选取一个图像索引
image_indx = np.random.randint(m)

# 获取模型预测结果:输入的是一个图像 batch(1 张),
# 输出后 squeeze 成单张,再转为 NumPy
image_arr = model(datax[image_indx:image_indx + 1]).squeeze(0).detach().cpu().numpy()

# 原始图像
ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
ax[row_num][0].set_title("Orignal Image")

# 模型预测图像
ax[row_num][1].imshow(np.squeeze((image_arr > 0.40)[0, :, :].astype(int)))
ax[row_num][1].set_title("Segmented Image localization")

# Ground Truth 掩码图像
ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
ax[row_num][2].set_title("Target image")

# 绘制
plt.show()

2.验证

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
def main():  
args = parse_args()

# 1. 加载配置文件
with open('models/%s/config.yml' % args.name, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)

# 打印配置信息,便于调试查看
print('-' * 20)
for key in config.keys():
print('%s: %s' % (key, str(config[key])))
print('-' * 20)

# 2. 初始化环境,启用 cuDNN 自动优化
cudnn.benchmark = True

# 3. 创建模型
print("=> creating model %s" % config['arch'])
model = archs.__dict__[config['arch']](config['num_classes'],
config['input_channels'],
config['deep_supervision'])
# 加载到 CUDA(如可用)
if torch.cuda.is_available():
model = model.cuda()
else:
device = torch.device("cpu")
model.to(device)

# 4.加载数据集,读取所有图像ID(去掉扩展名)
img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

# 拆分验证集,随机划分出 20% 的验证集
_, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

# 5.加载模型参数
model.load_state_dict(torch.load('models/%s/model.pth' % config['name']))
model.eval() # 设置为评估模式

# 6. 定义验证集增强,定义验证集图像增强(大小缩放 + 归一化)
val_transform = Compose([
A.Resize(config['input_h'], config['input_w']),
A.Normalize(),
])

# 加载验证集数据集
val_dataset = Dataset(
img_ids=val_img_ids,
img_dir=os.path.join('inputs', config['dataset'], 'images'),
mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
img_ext=config['img_ext'],
mask_ext=config['mask_ext'],
num_classes=config['num_classes'],
transform=val_transform)
# 构建数据加载器
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=config['batch_size'],
shuffle=False,
num_workers=config['num_workers'],
drop_last=False)

# 7. 评估模型 初始化 IoU 评估器
avg_meter = AverageMeter()

# 创建输出目录
for c in range(config['num_classes']):
os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)

# 8. 模型推理
with torch.no_grad(): # 验证阶段不计算梯度,节省显存
for input, target, meta in tqdm(val_loader, total=len(val_loader)):
if torch.cuda.is_available():
input = input.cuda()
target = target.cuda()
else:
device = torch.device("cpu")
input = input.to(device)
target = target.to(device)

# 模型前向传播
if config['deep_supervision']:
output = model(input)[-1] # 深度监督则使用最后一个输出
else:
output = model(input)

# 计算 IoU 得分并更新平均器
iou = iou_score(output, target)
avg_meter.update(iou, input.size(0))

# 输出后处理(Sigmoid + 转为 numpy)
output = torch.sigmoid(output).cpu().numpy()

# 将预测结果保存为图像
for i in range(len(output)):
for c in range(config['num_classes']):
cv2.imwrite(os.path.join('outputs', config['name'], str(c), meta['img_id'][i] + '.jpg'), (output[i, c] * 255).astype('uint8'))
# 9. 打印验证结果 平均 IoU
print('IoU: %.4f' % avg_meter.avg)

# 10.可视化样例
plot_examples(input, target, model, num_examples=3)

# 11.释放显存
torch.cuda.empty_cache()

可视化样例如下图:

5.总结

本文主要介绍了在PyTorch框架下使用U-Net++网络进行医学图像实例分割的代码实现:

  1. 数据预处理
    • 数据集特点:670张细胞图像,二值掩码标签(0背景/1细胞),图像尺寸较小
    • 关键操作:合并多个独立细胞掩码为统一mask,使用OpenCV进行图像缩放和通道处理
    • 数据增强:采用Albumentations库实现旋转/翻转/色彩扰动等增强策略
  2. U-Net++网络结构
    • 改进点:相比经典U-Net增加密集跳跃连接和深度监督机制
    • 核心组件:
      • VGGBlock:基础卷积模块(3x3卷积+BN+ReLU×2)
      • 嵌套解码结构:convX_Y命名规则表示第X层第Y次特征融合
    • 深度监督:中间层输出辅助损失,缓解梯度消失
  3. 模型训练
    • 损失函数:BCEDiceLoss(兼顾像素精度和区域重叠)和LovaszHingeLoss(优化IoU指标)
    • 训练策略:余弦退火学习率、早停机制、模型保存最佳检查点
    • 评估指标:IoU(交并比)作为主要评估标准
  4. 结果验证
    • 可视化对比:并列显示原图、预测结果和真实标注
    • 性能评估:批量计算验证集平均IoU,保存预测结果图像

6.备注

环境

  • mac: 15.2
  • python: 3.12.4
  • numpy: 1.26.4
  • opencv-python: 4.11.0.86
  • albumentations: 2.0.6

资源和代码

https://github.com/keychankc/dl_code_for_blog/tree/main/011_U-net%2B%2B