[pytorch][模型压缩] 通道裁剪后的模型设计——以MobileNet和ResNet为例

说明

模型裁剪可分为两种,一种是稀疏化裁剪,裁剪的粒度为值级别,一种是结构化裁剪,最常用的是通道裁剪。通道裁剪是减少输出特征图的通道数,对应的权值是卷积核的个数。

问题

通常模型裁剪的三个步骤是:1. 判断网络中不重要的通道 2. 删减掉不重要的通道(一般不会立即删,加mask等到评测时才开始删) 3. 将模型导出,然后进行finetue恢复精度。

步骤1,2涉及到非常多的标准和方法,这里不去深究。但是到第3步的时候,怎么导出网络,看似很简单的问题,但是如果碰到resnet这种,是要花费时间研究细节的,而且目前还没有人专门讲这块(实际上是个工程实现问题),下面来详细说说。

以MobileNet为代表的模型

先考虑以mobilenet为代表的模型,mobilenet中包含了一系列块,每块中包含了深度可分离卷积和点卷积,然后整个模型就是一系列block块的堆叠,在目前很多模型中都具有代表性。

首先我们只考虑了模型的11卷积,因为11卷积是最耗算力的,而33卷积的裁剪实际上没有必要,意味可分离意味着将输入特征图的信息丢掉,与其丢掉,那不如在一开始就不去计算要丢掉的那部分,而不计算的那部分正是由前一层的11点卷积得到的,也就是说改变前一层的输出通道,就等同于对当前的可分离卷积的裁剪。

然后问题就只剩下11卷积核的裁剪了,那么需要在模型初始化时设置不同的profile,来实现不同结构的模型裁剪模型,这里代码中的例子是将第一个block中11卷积核的128通道裁剪为64通道,其他通道可依次次类推。

class MobileNet(nn.Module):
    def __init__(self, n_class,  profile='normal', channels=None):
        self.channels = [32, 64, 104, 128, 248, 224, 456, 296, 456, 224, 104, 104, 208, 208]
        if channels:
            self.channsels = channels
        super(MobileNet, self).__init__()

        # original
        if profile == 'normal':
            in_planes = 32
            cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), (1024,1)]
        # 0.5 AMC
        elif profile == '0.5flops':
            in_planes = self.channels[0]
            strides = [1, 2, 1, 2, 1, 2, 1,1,1,1,1, 2,1]
            cfg = list(zip(self.channels[1:], strides))
        else:
            raise NotImplementedError

而在make_layers部分,需要判断当前stride, 有三次stride,每次缩放一倍,默认stride都是1,当然也可以把stride全列举出来,就不用判断了。

    def _make_layers(self, in_planes, cfg, layer):
        layers = []
        for x in cfg:
            out_planes = x if isinstance(x, int) else x[0]
            stride = 1 if isinstance(x, int) else x[1]
            layers.append(layer(in_planes, out_planes, stride))
            in_planes = out_planes
        return nn.Sequential(*layers)

以Resnet50为代表的模型

前面解决了mobileNet的问题,其实也是一个基本网络架构下的裁剪问题,但是目前的网络往往具有复杂的连接,比如像resnet这样,具有残差结构的单元块,这意味着残差部分需要单独处理。

在我压缩完成得到压缩配置之后,先写了简单版本的resnet_pruning版本,这是最朴素的思想:

def ResNet50_Pruning(**kwarg):
    model = ResNet(Bottleneck_Pruning, [3,4,6,3], **kwarg)
    p = 0
    actions = [3, 56, 64, 64, 48, 240, 16, 64, 152, 32, 32, 152, 120, 104, 216, 368, 112, 32, 480, 112, 120, 504, 88, 104, 104, 240, 184, 368, 768, 200, 200, 640, 232, 192, 976, 248, 192, 760, 160, 208, 584, 208, 248, 968, 496, 224, 208, 416, 104, 104, 416, 104, 104, 416]
    for i, m in enumerate(model.modules()):
        if type(m) in (nn.Conv2d, nn.Linear):
            if type(m) == nn.Conv2d and m.groups == m.in_channels:  # depth-wise conv, buffer
                continue
            else: 
                if type(m) is nn.Linear:
                   m.in_features = actions[p]
                else:    
                    m.in_channels = actions[p]
                    m.out_channels = actions[p+1]
                p += 1
    
    return model

将每一层对应的actions都找到,然后令其channel都做出改变,这无疑是最直观的写法,但是由于CONV之后往往带着BN层,当改完CONV之后,你发现BN还是原有的值,这就会使得维度不匹配。当然我有参考了Nvi-Lab的写法,可以先建模型,然后获取压缩的action的裁剪通道,然后重建一个new_conv代替原有的conv,这样写也行,也是一种思路,不过我觉得这样不优雅,而且容易漏东西。

然后我使用了另外一个思路,在建立模型的时候就建立一个裁剪之后的模型,但是由于resnet50有多个blocks,然后每个block中是一个瓶颈结构,这就需要你定位到哪一个block,以及该block中哪一个卷积,这样就存在的一个缺陷就是需要全局的index来标记当前层是第几层。具体的实现如下:

cfg = [3, 56, 64, 64, 48, 240, 16, 64, 152, 32, 32, 152, 120, 104, 216, 368, 112, 32, 480, 112, 120, 504, 88, 104, 104, 240, 184, 368, 768, 200, 200, 640, 232, 192, 976, 248, 192, 760, 160, 208, 584, 208, 248, 968, 496, 224, 208, 416, 104, 104, 416, 104, 104, 416]
shortcut = [1, 10, 22, 40]
block_nums = [3, 4, 6, 3]

def Sample(x, num):
    np.random.seed(2019)
    batch_size, channel_num, height, width = x.data.size()
    channel_index = np.random.choice(channel_num, num)
    x = x[:, channel_index, :, :]
    return x

class Bottleneck(nn.Module):
    def __init__(self, in_planes, planes, stride=1, offset=1):
        super(Bottleneck, self).__init__()

        # pw
        self.conv1 = nn.Conv2d(cfg[offset], cfg[offset+1], kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(cfg[offset+1])
        # dw
        self.conv2 = nn.Conv2d(cfg[offset+1], cfg[offset+2], kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(cfg[offset+2])
        # pw
        self.conv3 = nn.Conv2d(cfg[offset+2], cfg[offset+3], kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(cfg[offset+3])

        self.shortcut = nn.Sequential()
        if offset in shortcut:
            p = shortcut.index(offset)
            self.shortcut = nn.Sequential(
                nn.Conv2d(cfg[offset], cfg[offset+3], kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(cfg[offset+3])
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        s = self.shortcut(x)
        if s.data.size() != out.data.size():
            s = Sample(s, out.data.size(1))
        out += s
        out = F.relu(out)
        return out

可以看到,对于残差那一分支,采用的是sample采样的方法来使得通道数与瓶颈结构相同,之所以不对瓶颈结构中的卷积结果进行采样,是由于这样可以尽可能多地保留输入特征的信息。而瓶颈结构中多了offset参数用以标记当前的卷积的索引。