关于pytorch语义分割二分类问题的两种做法

形式1:输出为单通道

分析

即网络的输出 output 为 [batch_size, 1, height, width] 形状。其中 batch_szie 为批量大小,1 表示输出一个通道,heightwidth 与输入图像的高和宽保持一致。

在训练时,输出通道数是 1,网络得到的 output 包含的数值是任意的数。给定的 target ,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 output 不断逼近这个标签,首先会让 output 经过一个sigmoid 函数,使其数值归一化到[0, 1],得到 output1 ,然后让这个 output1target 进行交叉熵计算,得到损失值,反向传播更新网络权重。最终,网络经过学习,会使得 output1 逼近target

训练结束后,网络已经具备让输出的 output 经过转换从而逼近 target 的能力。首先将输出的 output 通过sigmoid 函数,然后取一个阈值(一般设置为0.5),大于阈值则取1反之则取0,从而得到预测图 predict。后续则是一些评估相关的计算。

代码实现

在这个过程中,训练的损失函数为二进制交叉熵损失函数,然后根据输出是否用到了sigmoid有两种可选的pytorch实现方式:

output = net(input)  # net的最后一层没有使用sigmoid
loss_func1 = torch.nn.BCEWithLogitsLoss()
loss = loss_func1(output, target)

当网络最后一层没有使用sigmoid时,需要使用 torch.nn.BCEWithLogitsLoss() ,顾名思义,在这个函数中,拿到output首先会做一个sigmoid操作,再进行二进制交叉熵计算。上面的操作等价于

output = net(input)  # net的最后一层没有使用sigmoid
output = F.sigmoid(output)
loss_func1 = torch.nn.BCEWithLoss()
loss = loss_func1(output, target)

当然,你也可以在网络最后一层加上sigmoid操作。从而省去第二行的代码(在预测时也可以省去)。

在预测试时,可用下面的代码实现预测图的生成

output = net(input)  # net的最后一层没有使用sigmoid
output = F.sigmoid(output)
predict = torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output))
...

即大于0.5的记为1,小于0.5记为0。

形式2:输出为多通道

分析

即网络的输出 output 为 [batch_size, num_class, height, width] 形状。其中 batch_szie 为批量大小,num_class 表示输出的通道数与分类数量一致,heightwidth 与输入图像的高和宽保持一致。

在训练时,输出通道数是 num_class(这里取2),网络得到的 output 包含的数值是任意的数。给定的 target ,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 output 不断逼近这个标签,首先会让 output 经过一个 softmax 函数,使其数值归一化到[0, 1],得到 output1 ,在各通道中,这个数值加起来会等于1。对于target 他是一个单通道图,首先使用onehot编码,转换成 num_class个通道的图像,每个通道中的取值是根据单通道中的取值计算出来的,例如单通道中的第一个像素取值为1(0<= 1 <=num_class-1,这里num_class=2),那么onehot编码后,在第一个像素的位置上,两个通道的取值分别为0,1。也就是说像素的取值决定了对应序号的通道取1,其他的通道取0,这个非常关键。上面的操作执行完后得到target1,让这个 output1target1 进行交叉熵计算,得到损失值,反向传播更新网路权重。最终,网络经过学习,会使得 output1 逼近target1(在各通道层面上)。

训练结束后,网络已经具备让输出的 output 经过转换从而逼近 target 的能力。计算 output 中各通道每一个像素位置上,取值最大的那个对应的通道序号,从而得到预测图 predict。后续则是一些评估相关的计算。

代码实现

在这个过程中,则可以使用交叉熵损失函数:

output = net(input)  # net的最后一层没有使用sigmoid
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(output, target)

根据前面的分析,我们知道,正常的output是 [batch_size, num_class, height, width]形状的,而target是[batch_size, height, width]形状的,需要按照上面的分析进行转换才可以计算交叉熵,而在pytorch中,我们不需要进一步做这个处理,直接使用就可以了。

在预测试时,使用下面的代码实现预测图的生成

output = net(input)  # net的最后一层没有使用sigmoid
predict = output.argmax(dim=1)
...

即得到输出后,在通道方向上找出最大值所在的索引号。

小结

总的来说,我觉得第二种方式更值得推广,一方面不用考虑阈值的选取问题;另一方面,该方法同样适用于多类别的语义分割任务,通用性更强。

参考资料

[1]https://blog.csdn.net/longshaonihaoa/article/details/105253553

[2]https://cuijiahua.com/blog/2020/03/dl-16.html