百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

ThiNet:模型通道结构化剪枝(模型通道图)

toyiye 2024-07-15 01:26 7 浏览 0 评论

作者: LoBob

来源:微信公众号:GiantPandaCV

出处:https://mp.weixin.qq.com/s?__biz=MzA4MjY4NTk0NQ==&mid=2247492572&idx=1&sn=f2cab6d489805d2e4392af81d4e86a1b


ThiNet是一种结构化剪枝,核心思路是找到一个channel的子集可以近似全集,那么就可以丢弃剩下的channel,对应的就是剪掉剩下的channel对应的filters。剪枝算法还是三步剪枝:train-prune-finetune,而且是layer by layer的剪枝。本文由作者授权首发于GiantPandaCV公众号。

0、 介绍

ThiNet是南京大学lamda实验室出品,是ICCV 2017的文章,文章全名《ThiNet: A Filter Level Pruning Method for Deep Neural Network Compression》。

文章的主要思路是:ThiNet是基于filter剪枝,将filter剪枝操作形式化地定义为一个优化问题, 通过下一层的统计信息来指导当前层的剪枝 。如果移除当前层(记为)filter(记为),那么层channel和同样被丢弃;但是如果层的filter的数量不变,则层的输出(也是层的输入)维度不变。也就是发现这样的剪枝对层的输出(也是层的输入)很小影响,作者提出ThiNet剪枝。大白话就是找到一组channel的输出跟全部channel的输出之间的误差最小(采用均方误差/最小二乘法去衡量),那么就可以用这组channel来代替全部channel。

ThiNet剪枝流程:选择channel子集、剪枝、finetune,如下如图

ThiNet剪枝流程

所以算法的实现的核心在于 如何进行channel选择 ,一个channel是一个filter的计算结果,所以二者相互对应。

ThiNet有三个要点:

1、如何进行通道选择,通道的子集与全部通道的全集之间的最小二次乘法误差来做通道重要性判断依据

2、最小化重构误差,相当于给finetune一个初始化卷积核参数

3、对残差网络的剪枝做了适配

一、通道选择(channel selection)

文章采用贪心算法选择channel子集(也就是留下来的filter)。ThiNet是迭代式layer by layer的剪枝。

思路1(正向思路):根据通道重要性判断找到重要的channel,保留下来,然后迭代式剪枝进进行直到压缩率达到预设要求,见公式5。

为什么会有思路1?因为论文的主要思路是,找到一组channel的子集可以近似该层channel的全集,那么就是要找到可以留下来的channel,对应的就是该channel对应的filter;这就是论文的正向思路。

思路1的方法会有一个问题就是,留下来filter的数量是从大到小的变化的,那么按照思路1计算量会很大,因为留下来的filter(记为S)在剪枝一开始的时候要比被移除的filter(记为T)多,所以有

思路2:根据通道重要性判断找到要剪枝(丢弃)的filter,然后迭代式剪枝进行直到压缩率达到预设要求(丢弃一定数量的filter),见公式6。

ThiNet通道重要性判断是:找到一组通道子集近似通道全集的结果。

公式1-3的图片解释

下面公式1-5我是根据论文写的,会有点绕,但对复现这篇论文不是那么重要,核心思路就是上面提到选取一部分channel来近似。

公式1:

公式2:

其中,是第层输入张量,

是从中随机采样得到的,

是卷积核的集合,

是对应的滑动窗口,

是channels,是行,是列,是输出的通道数,是bias

公式1和公式2,可以简化为公式3:

公式1~3是为了简化公式表示的等效变换

基于通道在中是独立的,只取决于,不依赖于,,则有

公式4:,是channel的子集。

公式5是为了最小化留下来的channel的计算结果与原来channel全集的计算结果,即为思路1:

变为公式6,即为思路2:

其中,S ∪ T = {1, 2, . . . , C},S ∩ T = ?,r是压缩率,C是filter数量。

基于贪心算法选择filter子集的算法如图:

贪心算法

def channel_selection(inputs, module, sparsity=0.5, method='greedy'):
    """
    选择当前模块的输入通道,以及高度重要的通道。
    找到可以使现有输出最接近的输入通道。
    
    :param inputs: torch.Tensor, input features map
    :param module: torch.nn.module, layer
    :param sparsity: float, 0 ~ 1 how many prune channel of output of this layer
    :param method: str, how to select the channel
    :return:
        list of int, indices of channel to be selected and pruned
    """
    num_channel = inputs.size(1)  # 通道数
    num_pruned = int(math.ceil(num_channel * sparsity))  #  输入需要删除的通道数
    num_stayed = num_channel - num_pruned

    print('num_pruned', num_pruned)
    if method == 'greedy':
        indices_pruned = []
        while len(indices_pruned) < num_pruned:
            min_diff = 1e10
            min_idx = 0
            for idx in range(num_channel):
                if idx in indices_pruned:
                    continue
                indices_try = indices_pruned + [idx]
                inputs_try = torch.zeros_like(inputs)
                inputs_try[:, indices_try, ...] = inputs[:, indices_try, ...]
                output_try = module(inputs_try)
                output_try_norm = output_try.norm(2) #这里就是公式6
                if output_try_norm < min_diff:
                    min_diff = output_try_norm
                    min_idx = idx
            indices_pruned.append(min_idx)

        indices_stayed = list(set([i for i in range(num_channel)]) - set(indices_pruned))
        
    inputs = inputs.cuda()
    module = module.cuda()

    return indices_stayed, indices_pruned

二、最小化重构误差(Minimize the reconstruction error)

首先先来看看numpy.linalg.lstsq(),是线性矩阵方程的最小二乘法求解。

最小二乘法的公式为:

方法描述linalg.lstsq(a, b[, rcond])返回线性矩阵方程的最小二乘解

numpy.linalg.lstsq(a, b, rcond='warn')
# 将least-squares解返回线性矩阵方程。

其中,是通道选择后的训练样本,可以通过求解

该方法是 每一个通道赋予权重来进一步地减小重构误差。文章说这相当于给finetune一个很好的初始化 。

def weight_reconstruction(module, inputs, outputs, use_gpu=False):
    """
    reconstruct the weight of the next layer to the one being pruned
    :param module: torch.nn.module, module of the this layer
    :param inputs: torch.Tensor, new input feature map of the this layer
    :param outputs: torch.Tensor, original output feature map of the this layer
    :param use_gpu: bool, whether done in gpu
    :return: void
    """
    if module.bias is not None:
        bias_size = [1] * outputs.dim()
        bias_size[1] = -1
        outputs -= module.bias.view(bias_size)  # 从 output feature 中减去 bias (y - b)
    if isinstance(module, torch.nn.Conv2d):
        unfold = torch.nn.Unfold(kernel_size=module.kernel_size, dilation=module.dilation,
                                 padding=module.padding, stride=module.stride)

        unfold.eval()
        x = unfold(inputs)  # 展开到以一个面片(reception field)为列的三维数组 (N * KKC * L (number of fields))
        x = x.transpose(1, 2)  #  transpose (N * KKC * L) -> (N * L * KKC)
        num_fields = x.size(0) * x.size(1)
        x = x.reshape(num_fields, -1)  # x: (NL * KKC)
        y = outputs.view(outputs.size(0), outputs.size(1), -1)  # 将一个特征映射展开为一行数组 (N * C * WH)
        y = y.transpose(1, 2)  #  transpose (N * C * HW) -> (N * HW * C), L == HW
        y = y.reshape(-1, y.size(2))  # y: (NHW * C),  (NHW) == (NL)

        if x.size(0) < x.size(1) or use_gpu is False:
            x, y = x.cpu(), y.cpu()
            
 #上面一系列的reshape的操作是为了调用np.linalg.lstsq这个函数,利用最小二乘法求解weight
    param, residuals, rank, s = np.linalg.lstsq(x.detach().cpu().numpy(),y.detach().cpu().numpy(),rcond=-1)

    param = param[0:x.size(1), :].clone().t().contiguous().view(y.size(1), -1)
    if isinstance(module, torch.nn.Conv2d):
        param = param.view(module.out_channels, module.in_channels, *module.kernel_size)
    del module.weight
    module.weight = torch.nn.Parameter(param)

三、对于VGG-16的ThiNet剪枝策略

1、对前面10层剪枝力度大,因为前面10层的feature map比较大,FLOPs占据了超过90%

2、全连接层占据了 86.41%的模型参数,所以将其改成global average pooling layer

3、剪枝是layer by layer,每剪完一个layer finetune一个epoch,学习率设为0.001,到最后一层剪完 finetune 12个epoch,学习率设为0.0001.

4、在Imagenet上, VGG更具体的剪枝细节可以看论文4.2部分。

四、对于ResNet的剪枝策略

对于残差块的剪枝,因为有个add的操作,相加时候维度必须保持一致,所以残差块最后一层输出的filter不改变而只剪枝前面两层,如下所示:

ResNet剪枝图

每剪完一个layer finetune一个epoch,固定学习率为0.0001,到最后一层剪完 finetune 9个epoch ,学习率从0.001到0.00001变换,其余的与VGG-16中一样。ResNet更具体细节,请查看论文4.3部分

五、参考链接

原作中文解读:http://www.lamda.nju.edu.cn/luojh/project/ThiNet_ICCV17/ThiNet_ICCV17_CN.html

论文:https://arxiv.org/abs/1707.06342

代码:https://github.com/Roll920/ThiNet

https://github.com/kkeono2/Channel-Pruning-using-Thinet-LASSO-


作者: LoBob

来源:微信公众号:GiantPandaCV

出处:https://mp.weixin.qq.com/s?__biz=MzA4MjY4NTk0NQ==&mid=2247492572&idx=1&sn=f2cab6d489805d2e4392af81d4e86a1b

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码