yolov5 Model Compression: Model Pruning

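The YOLO family is currently one of the most widely used groups of detection models in engineering practice. yolov5 detects well and is easy to deploy, which has made it popular with developers. However, when a model runs on front-end devices, the requirements on model size and inference time are strict, and even the lightweight yolov5s struggles to meet them. To improve model efficiency, this post shares a model pruning method based on yolov5 (GitHub share link).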

Pruning principle and pipeline

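This post prunes along the channel dimension using sparse training, following the paper Learning Efficient Convolutional Networks Through Network Slimming. The idea is easy to follow. A BN layer has two trainable parameters, $\gamma$ and $\beta$: the input is normalized, then scaled by $\gamma$ and shifted by $\beta$. When $\gamma$ and $\beta$ both approach 0, the normalized input is effectively multiplied by 0, so that channel can only output 0 and carries no information. Removing such redundant channels should therefore have little effect on model performance. In ordinary training, because of initialization, $\gamma$ is generally distributed around 1. To push $\gamma$ toward 0, an L1 regularizer can be added so that the coefficients become sparse; the paper calls training with an L1 penalty on $\gamma$ sparse training.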
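For reference, the two formulas behind this argument can be written out as below. The notation ($\mu_B$, $\sigma_B^2$, $\epsilon$, $\ell$, $\lambda$, $\Gamma$) follows common BN and Network Slimming conventions and is an assumption of this sketch, not something taken from the original post; $\lambda$ plays the role of the sparsity coefficient, $t$ is the training target, and $\Gamma$ is the set of all BN scale factors.

```latex
% Batch normalization for one channel: normalize, then scale and shift
\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta

% Sparse-training objective: task loss plus an L1 penalty on the BN scale factors
L = \sum_{(x,\, t)} \ell\big(f(x, W),\, t\big) + \lambda \sum_{\gamma \in \Gamma} |\gamma|
```

With $\gamma = \beta = 0$ the first formula gives $y = 0$ for every input, which is why such a channel can be removed.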

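The whole pruning process is shown in the figure below. First, initialize the network, add L1 regularization to the BN-layer parameters, and train the network. Then collect the $\gamma$ values across the network and prune it according to a chosen pruning ratio. Finally, fine-tune the pruned network to complete the pruning.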

[Figure: pruning pipeline]

Pruning details

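1. Sparse training
The previous section covered the principle of sparse training; now let's look at how the code implements it. First, we need to set the sparsity coefficient, which is critical to how well the whole network can be pruned. If it is too small, $\gamma$ is not pushed close enough to 0 and the network cannot be pruned aggressively; if it is too large, network performance suffers and mAP drops sharply. The right coefficient therefore has to be found by experiment.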

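The trainable parameters of a BN layer are $\gamma$ and $\beta$, which are m.weight and m.bias in the code. After loss.backward(), we simply add the gradient of the L1 term to the gradients of these two parameters. In the code, the coefficient applied to $\gamma$ decays linearly with the epoch, while $\beta$ uses a fixed coefficient of 10 * opt.sr.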

    # Sparsity coefficient, decayed linearly from opt.sr to 0.1 * opt.sr over training
    srtmp = opt.sr * (1 - 0.9 * epoch / epochs)
    for k, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d) and (k not in ignore_bn_list):
            m.weight.grad.data.add_(srtmp * torch.sign(m.weight.data))  # L1
            m.bias.grad.data.add_(opt.sr * 10 * torch.sign(m.bias.data))  # L1
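
The update above can be illustrated without PyTorch. The sketch below is a plain-Python illustration, not the repository's code: `sparsity_coeff`, `sign`, and `add_l1_subgradient` are hypothetical names, and lists of floats stand in for `m.weight.data` and `m.weight.grad.data`. It shows the two ingredients: a coefficient that decays linearly from `sr` to `0.1 * sr`, and the L1 subgradient `coeff * sign(w)` added to the existing gradient.

```python
def sparsity_coeff(sr, epoch, epochs):
    """Linearly decay the sparsity coefficient from sr (epoch 0) to 0.1 * sr (epoch == epochs)."""
    return sr * (1 - 0.9 * epoch / epochs)


def sign(x):
    """Sign of a scalar as a float: -1.0, 0.0, or 1.0."""
    return (x > 0) - (x < 0) + 0.0


def add_l1_subgradient(grads, weights, coeff):
    """Return the gradients with the subgradient of coeff * |w| added for each weight."""
    return [g + coeff * sign(w) for g, w in zip(grads, weights)]


if __name__ == "__main__":
    coeff = sparsity_coeff(sr=0.001, epoch=50, epochs=100)
    gamma = [0.8, -0.3, 0.0]       # stand-in for m.weight.data
    grad = [0.01, 0.02, -0.005]    # stand-in for m.weight.grad.data
    print(coeff)
    print(add_l1_subgradient(grad, gamma, coeff))
```

Because the added term has the same sign as the weight, a gradient-descent step moves every nonzero $\gamma$ a little closer to 0, which is what produces the sparsity.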

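2. Network pruning
The previous step gives us a sparsely trained network. Next, we need to prune the channels whose $\gamma$ is close to 0. First, collect the $\gamma$ values of all BN layers, sort them, and find the threshold thre that corresponds to the pruning ratio.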

    # Collect every BN layer that is allowed to be pruned
    for i, layer in model.named_modules():
        if isinstance(layer, nn.BatchNorm2d):
            if i not in ignore_bn_list:
                model_list[i] = layer
                # bnw = layer.state_dict()['weight']
    model_list = {k: v for k, v in model_list.items() if k not in ignore_bn_list}
    prune_conv_list = [layer.replace("bn", "conv") for layer in model_list.keys()]
    bn_weights = gather_bn_weights(model_list)
    # Sort all gamma values and read off the threshold at the pruning ratio
    sorted_bn = torch.sort(bn_weights)[0]
    thre_index = int(len(sorted_bn) * opt.percent)
    thre = sorted_bn[thre_index]
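
As a rough illustration of the threshold selection, the sketch below uses plain Python lists instead of tensors. `gather_abs_gammas` is a hypothetical stand-in for the `gather_bn_weights` helper, whose body the post does not show; the assumption here is that it concatenates the absolute $\gamma$ values of all prunable BN layers. The layer names and values are made up for the example.

```python
def gather_abs_gammas(bn_layers):
    """Flatten |gamma| of every layer into one list (assumed behaviour of gather_bn_weights)."""
    return [abs(g) for gammas in bn_layers.values() for g in gammas]


def global_threshold(bn_layers, percent):
    """Return the |gamma| value found at position int(N * percent) of the sorted list."""
    sorted_bn = sorted(gather_abs_gammas(bn_layers))
    thre_index = int(len(sorted_bn) * percent)
    return sorted_bn[thre_index]


if __name__ == "__main__":
    layers = {
        "model.0.bn": [0.9, 0.02, -0.6, 0.01],
        "model.1.bn": [0.03, 0.7, -0.04, 0.8],
    }
    print(global_threshold(layers, percent=0.25))
```

Note that the threshold is global: a layer whose $\gamma$ values are all small can lose far more than the nominal pruning ratio, which is one reason the per-layer mask logic adds its own safeguards.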

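Then, the mask of each BN layer is obtained from the threshold. Some extra logic is added here so that the channel count after pruning stays a multiple of 4, which meets the requirements of front-end acceleration.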

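    def obtain_bn_mask(bn_module, thre):
        thre = thre.cuda()
        bn_layer = bn_module.weight.data.abs()
        # Candidate cut points: every 4th sorted |gamma|, so cuts remove channels in groups of 4
        temp = abs(torch.sort(bn_layer)[0][3::4] - thre)
        thre_temp = torch.sort(bn_layer)[0][3::4][temp.argmin()]
        if int(temp.argmin()) == 0 and thre_temp > thre:
            thre = -1  # smallest candidate already exceeds the global threshold: keep all channels
        else:
            thre = thre_temp
        # Cap the cut at roughly 90% of the channels, rounded down to a multiple of 4
        thre_index = int(bn_layer.shape[0] * 0.9)
        if thre_index % 4 != 0:
            thre_index -= thre_index % 4
        thre_perbn = torch.sort(bn_layer)[0][thre_index - 1]
        if thre_perbn < thre:
            thre = min(thre, thre_perbn)
        mask = bn_module.weight.data.abs().gt(thre).float()
        return mask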
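
To make the mask logic easier to trace by hand, here is a plain-Python re-expression of the same steps. It is a sketch for illustration only: `bn_mask_sketch` is a hypothetical name, a list of floats stands in for `bn_module.weight.data`, and it assumes a layer with at least 5 channels and distinct $|\gamma|$ values.

```python
def bn_mask_sketch(gammas, thre):
    """Return a 0/1 mask for one BN layer, cutting channels in groups of 4."""
    mags = sorted(abs(g) for g in gammas)
    # Every 4th sorted magnitude; cutting at one of these removes a multiple of 4 channels.
    candidates = mags[3::4]
    nearest = min(range(len(candidates)), key=lambda i: abs(candidates[i] - thre))
    if nearest == 0 and candidates[nearest] > thre:
        local_thre = -1.0  # keep every channel in this layer
    else:
        local_thre = candidates[nearest]
    # Cap the cut at about 90% of the channels, rounded down to a multiple of 4.
    cap_index = int(len(mags) * 0.9)
    cap_index -= cap_index % 4
    local_thre = min(local_thre, mags[cap_index - 1])
    return [1.0 if abs(g) > local_thre else 0.0 for g in gammas]


if __name__ == "__main__":
    gammas = [0.9, 0.01, 0.5, 0.02, 0.03, 0.7, 0.04, 0.8]
    print(bn_mask_sketch(gammas, thre=0.05))
```

In this 8-channel example with a global threshold of 0.05, the nearest candidate is 0.04, so the four smallest channels are removed and four remain. Since channels are removed in groups of 4, the kept count is a multiple of 4 whenever the layer's original channel count is.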

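Because the channels of the pruned network no longer line up with those of the original network, we need to redefine the network and parse it again. The rebuilt network structure has to be redefined because its modules take more parameters.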

    pruned_yaml["backbone"] = [
        [-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
        [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
        [-1, 3, C3Pruned, [128]],
        [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
        [-1, 6, C3Pruned, [256]],
        [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
        [-1, 9, C3Pruned, [512]],
        [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
        [-1, 3, C3Pruned, [1024]],
        [-1, 1, SPPFPruned, [1024, 5]],  # 9
    ]
    pruned_yaml["head"] = [
        [-1, 1, Conv, [512, 1, 1]],
        [-1, 1, nn.Upsample, [None, 2, 'nearest']],
        [[-1, 6], 1, Concat, [1]],  # cat backbone P4
        [-1, 3, C3Pruned, [512, False]],  # 13
        [-1, 1, Conv, [256, 1, 1]],
        [-1, 1, nn.Upsample, [None, 2, 'nearest']],
        [[-1, 4], 1, Concat, [1]],  # cat backbone P3
        [-1, 3, C3Pruned, [256, False]],  # 17 (P3/8-small)
        [-1, 1, Conv, [256, 3, 2]],
        [[-1, 14], 1, Concat, [1]],  # cat head P4
        [-1, 3, C3Pruned, [512, False]],  # 20 (P4/16-medium)
        [-1, 1, Conv, [512, 3, 2]],
        [[-1, 10], 1, Concat, [1]],  # cat head P5
        [-1, 3, C3Pruned, [1024, False]],  # 23 (P5/32-large)
        [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
    ]

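The yolov5 backbone and neck contain C3 modules, and C3 contains shortcuts, i.e. the outputs of two convolutions are added together. For the add to work, the masks of the two convolutions being added must be merged. The network also contains concat operations, so we additionally need to record which layers each concat takes its input from and which layer consumes its output.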
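
The mask merge for a shortcut can be sketched in a few lines. This is an illustrative plain-Python version with a hypothetical helper name (`merge_masks`), using 0/1 lists in place of the tensors stored in `maskbndict`: a channel survives if any branch feeding the addition wants to keep it, so every branch ends up with the same channel set.

```python
def merge_masks(*masks):
    """Element-wise OR of 0/1 masks that all have the same length."""
    if len({len(m) for m in masks}) != 1:
        raise ValueError("all masks must have the same length")
    return [1.0 if any(bits) else 0.0 for bits in zip(*masks)]


if __name__ == "__main__":
    cv1 = [1.0, 0.0, 0.0, 1.0]             # mask of the branch entering the shortcut
    bottleneck_cv2 = [0.0, 0.0, 1.0, 1.0]  # mask of the conv whose output is added to it
    merged = merge_masks(cv1, bottleneck_cv2)
    print(merged)
```

The trade-off of an OR merge is that it keeps at least as many channels as any single mask would, so layers tied together by a shortcut are pruned less aggressively than isolated ones.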

    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except NameError:
                pass

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        named_m_base = "model.{}".format(i)
        if m in [Conv]:
            named_m_bn = named_m_base + ".bn"

            bnc = int(maskbndict[named_m_bn].sum())
            c1, c2 = ch[f], bnc
            args = [c1, c2, *args[1:]]
            layertmp = named_m_bn
            if i > 0:
                from_to_map[layertmp] = fromlayer[f]
            fromlayer.append(named_m_bn)
        elif m in [C3Pruned]:
            named_m_cv1_bn = named_m_base + ".cv1.bn"
            named_m_cv2_bn = named_m_base + ".cv2.bn"
            named_m_cv3_bn = named_m_base + ".cv3.bn"
            from_to_map[named_m_cv1_bn] = fromlayer[f]
            from_to_map[named_m_cv2_bn] = fromlayer[f]
            fromlayer.append(named_m_cv3_bn)
            if len(args) == 1:
                # Shortcut present: merge (OR) the masks of all convs that are added together
                temp_mask = maskbndict[named_m_cv1_bn].bool() | maskbndict[named_m_base + '.m.0.cv2.bn'].bool()
                maskbndict[named_m_cv1_bn], maskbndict[named_m_base + '.m.0.cv2.bn'] = temp_mask.float(), temp_mask.float()
                if n > 1:
                    for repeat_ind in range(1, n):
                        temp_mask |= maskbndict[named_m_base + ".m.{}.cv2.bn".format(repeat_ind)].bool()
                    for re_ind in range(n):
                        maskbndict[named_m_base + ".m.{}.cv2.bn".format(re_ind)] = temp_mask
                    maskbndict[named_m_cv1_bn], maskbndict[named_m_base + '.m.0.cv2.bn'] = temp_mask.float(), temp_mask.float()

            cv1in = ch[f]
            cv1out = int(maskbndict[named_m_cv1_bn].sum())
            cv2out = int(maskbndict[named_m_cv2_bn].sum())
            cv3out = int(maskbndict[named_m_cv3_bn].sum())
            args = [cv1in, cv1out, cv2out, cv3out, n, args[-1]]
            bottle_args = []
            chin = [cv1out]

            c3fromlayer = [named_m_cv1_bn]
            for p in range(n):
                named_m_bottle_cv1_bn = named_m_base + ".m.{}.cv1.bn".format(p)
                named_m_bottle_cv2_bn = named_m_base + ".m.{}.cv2.bn".format(p)
                bottle_cv1in = chin[-1]
                bottle_cv1out = int(maskbndict[named_m_bottle_cv1_bn].sum())
                bottle_cv2out = int(maskbndict[named_m_bottle_cv2_bn].sum())
                chin.append(bottle_cv2out)
                bottle_args.append([bottle_cv1in, bottle_cv1out, bottle_cv2out])
                from_to_map[named_m_bottle_cv1_bn] = c3fromlayer[p]
                from_to_map[named_m_bottle_cv2_bn] = named_m_bottle_cv1_bn
                c3fromlayer.append(named_m_bottle_cv2_bn)
            args.insert(4, bottle_args)
            c2 = cv3out
            n = 1
            from_to_map[named_m_cv3_bn] = [c3fromlayer[-1], named_m_cv2_bn]
        elif m in [SPPFPruned]:
            named_m_cv1_bn = named_m_base + ".cv1.bn"
            named_m_cv2_bn = named_m_base + ".cv2.bn"
            cv1in = ch[f]
            from_to_map[named_m_cv1_bn] = fromlayer[f]
            from_to_map[named_m_cv2_bn] = [named_m_cv1_bn] * 4
            fromlayer.append(named_m_cv2_bn)
            cv1out = int(maskbndict[named_m_cv1_bn].sum())
            cv2out = int(maskbndict[named_m_cv2_bn].sum())
            args = [cv1in, cv1out, cv2out, *args[1:]]
            c2 = cv2out

        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
            inputtmp = [fromlayer[x] for x in f]
            fromlayer.append(inputtmp)
        elif m is Detect:
            from_to_map[named_m_base + ".m.0"] = fromlayer[f[0]]
            from_to_map[named_m_base + ".m.1"] = fromlayer[f[1]]
            from_to_map[named_m_base + ".m.2"] = fromlayer[f[2]]
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]
            fromtmp = fromlayer[-1]
            fromlayer.append(fromtmp)

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save), from_to_map

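After rebuilding and parsing the network, we need to fill in its parameters: for each layer of the parsed network, find the corresponding parameters in the original network, clone them, and assign them to the rebuilt network. The code is as follows: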

    for ((layername, layer), (pruned_layername, pruned_layer)) in zip(model.named_modules(), pruned_model.named_modules()):
        assert layername == pruned_layername
        if isinstance(layer, nn.Conv2d) and not layername.startswith("model.24"):
            convname = layername[:-4] + "bn"
            if convname in from_to_map.keys():
                former = from_to_map[convname]
                if isinstance(former, str):
                    # Single input layer: slice input channels by its mask, output channels by our own mask
                    out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername[:-4] + "bn"].cpu().numpy())))
                    in_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[former].cpu().numpy())))
                    w = layer.weight.data[:, in_idx, :, :].clone()
                    if len(w.shape) == 3:  # remain only 1 channel.
                        w = w.unsqueeze(1)
                    w = w[out_idx, :, :, :].clone()

                    pruned_layer.weight.data = w.clone()
                    changed_state.append(layername + ".weight")
                if isinstance(former, list):
                    # Concatenated inputs: offset each source's kept indices by the original widths before it
                    orignin = [modelstate[i + ".weight"].shape[0] for i in former]
                    formerin = []
                    for it in range(len(former)):
                        name = former[it]
                        tmp = [i for i in range(maskbndict[name].shape[0]) if maskbndict[name][i] == 1]
                        if it > 0:
                            tmp = [k + sum(orignin[:it]) for k in tmp]
                        formerin.extend(tmp)
                    out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername[:-4] + "bn"].cpu().numpy())))
                    w = layer.weight.data[out_idx, :, :, :].clone()
                    pruned_layer.weight.data = w[:, formerin, :, :].clone()
                    changed_state.append(layername + ".weight")
            else:
                # No recorded input layer (first conv): only the output channels are sliced
                out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername[:-4] + "bn"].cpu().numpy())))
                w = layer.weight.data[out_idx, :, :, :].clone()
                assert len(w.shape) == 4
                pruned_layer.weight.data = w.clone()
                changed_state.append(layername + ".weight")

        if isinstance(layer, nn.BatchNorm2d):
            out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername].cpu().numpy())))
            pruned_layer.weight.data = layer.weight.data[out_idx].clone()
            pruned_layer.bias.data = layer.bias.data[out_idx].clone()
            pruned_layer.running_mean = layer.running_mean[out_idx].clone()
            pruned_layer.running_var = layer.running_var[out_idx].clone()
            changed_state.append(layername + ".weight")
            changed_state.append(layername + ".bias")
            changed_state.append(layername + ".running_mean")
            changed_state.append(layername + ".running_var")
            changed_state.append(layername + ".num_batches_tracked")

        if isinstance(layer, nn.Conv2d) and layername.startswith("model.24"):
            # Detect head convs: only the input channels change
            former = from_to_map[layername]
            in_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[former].cpu().numpy())))
            pruned_layer.weight.data = layer.weight.data[:, in_idx, :, :]
            pruned_layer.bias.data = layer.bias.data
            changed_state.append(layername + ".weight")
            changed_state.append(layername + ".bias")
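
The list branch above is probably the least obvious part: when a convolution reads from a concat, the kept input channels of each source have to be shifted by the original (unpruned) widths of the sources that come before it. A minimal plain-Python sketch of that index arithmetic follows; `concat_input_indices` is a hypothetical name, and 0/1 lists stand in for the mask tensors.

```python
def concat_input_indices(masks):
    """Map the kept channels of each concatenated source to indices in the original concat tensor."""
    kept = []
    offset = 0
    for mask in masks:
        kept.extend(offset + i for i, bit in enumerate(mask) if bit == 1)
        offset += len(mask)  # original width of this source, pruned channels included
    return kept


if __name__ == "__main__":
    first_source = [1, 0, 1, 0]   # keeps channels 0 and 2 of a 4-channel layer
    second_source = [0, 1, 1]     # keeps channels 1 and 2 of a 3-channel layer
    print(concat_input_indices([first_source, second_source]))
```

The resulting list plays the role of `formerin` in the code above: it indexes the input-channel dimension of the original convolution weight.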

At this point, we have completed all the pruning steps.


This article is reposted from: https://blog.csdn.net/litt1e/article/details/125818244
Copyright belongs to the original author, 小小小绿叶. In case of infringement, please contact us for removal.
