Soft Filter Pruning (SFP)——允許更新Pruned Filters的Channel Pruning策略

論文地址:https://arxiv.org/abs/1808.06866

GitHub (PyTorch):https://github.com/he-y/soft-filter-pruning

"Soft Filter Pruning for Accelerating Deep Convolutional Neural Networks"這篇文章首先強調了結構稀疏的優勢,基於結構稀疏的channel pruning不需要特定存儲格式和算法庫的支持,能夠充分利用成熟算法庫或框架以運行剪枝後模型,因而自然地適配PAI-Blade、TensorRT、MNN和AliNPU等DL推理框架或推理芯片。

文章同時提到傳統的"hard filter pruning"依賴於預訓練模型,且獲得pruning mask之後直接刪除pruned filters,結果導致隨着模型容量的減少,推理精度急劇下降,尤其是剪枝比例超過50%的情況下,推理精度的下降將會非常顯著;另外,爲了恢復損失的精度,需要額外的、相對耗時的fine-tuning過程;並且,直接刪除的filters不再接受參數更新 (hard pruned away),顯得簡單粗糙,通常爲了獲得較大的剪枝率,需要多次迭代地實施剪枝、fine-tune操作。

如上圖所示,文章爲此提出了"soft filter pruning (SFP)"策略,允許模型從隨機初始化開始(從預訓練模型開始能獲得更好的效果),並在每個epoch訓練開始之前,將具有較小L2-norm的filters置零,然後更新所有filters(包括未剪枝和已剪枝filters),最終模型收斂以後再把一些不重要的filters(zero-filters)裁剪掉,從而獲得模型容量較高、推理精度較高的正則化、剪枝結果。顯然該策略類似於DSD(Dense-Sparsity-Dense)的正則化、剪枝策略,能夠充分利用每個權重連接(無論是未剪枝和已剪枝的連接)的記憶作用,達到理想的正則化效果,並驅使既定比例的權重係數趨於稀疏化。

Soft Filter Pruning (SFP)策略如上圖所示,主要分爲四個步驟:

  1. filter selection:採用L2-norm(作爲importance衡量準則)以及預先定義的剪枝率Pi,選擇出一些不重要的filters;
  2. filter pruning:在每個epoch訓練開始之前,在全局層面將不重要的filters置零,並允許置零的filters在當前epoch訓練期間接受參數更新(soft-manner,不同於greedy selection),從而更好地平衡每個filter的貢獻;
  3. reconstruction:通過反向傳播更新所有filters,能夠讓pruned model按照與原始模型相同的容量接受參數更新。顯然,置零filters對應的前向輸出將趨於零,導致下一層input filters的梯度也趨於零,經梯度反傳之後,對應的參數更新幅度也會變得很小,而重要filters仍然接受正常的參數更新,從而達到理想的正則化效果;
  4. obtaining compact model:最終正則化、收斂以後,通過裁減掉zero filters可以獲得結構緊湊的網絡模型,同時達到理想的壓縮與加速效果;

實驗部分:文章在Cifar10、ImageNet2012數據集上對Resnet做了測試,獲得了理想的剪枝效果,具體結果見文章。針對端到端網絡,文章基於SFP實施了均勻剪枝策略,而在獲得每一層的pruning mask之後,也很容易設計端到端的非均勻網絡剪枝。

總的來說,SFP剪枝策略首先通過正則化,降低了模型的過擬合風險,獲得了包含一定稀疏度的待剪枝模型,非常適合如Resnet、ResNext和VGG等含BN層或不含BN的CNN網絡的結構性剪枝。另外,模型正則化之後,3D filters的重要性衡量準則可替換爲Taylor Expansion Criteria等,或許能獲得更好的剪枝效果,如下所示(Taylor Expansion Criteria與L2 norm):

for k, (n, m) in enumerate(model.named_modules()):
        # compute importance rank
        if isinstance(m, nn.Conv2d) and (k <= ML):
            if method == 'taylor':
                rank_temp_avg = torch.zeros(m.weight.data.shape[0]).float()
                for i in range(args.iters):
                    activation, grad = acts[i][index], grads[i][index]
                    rank_temp = torch.sum((activation * grad), dim = 0, keepdim=True).\
        				        sum(dim=2, keepdim=True).sum(dim=3, keepdim=True)[0, :, 0, 0].data
                    #rank_temp = torch.abs(rank_temp)
                    rank_temp = torch.abs(rank_temp / float(activation.size(0) * activation.size(2) * activation.size(3)))
                    rank_temp = rank_temp / torch.sqrt(torch.sum(rank_temp * rank_temp))
                    rank_temp_avg += rank_temp
                rank_temp_avg /= args.iters
                rank_temp = rank_temp_avg.cuda()
            elif method == 'l2':
                weight_torch = m.weight.data
                weight_vec = weight_torch.view(weight_torch.size()[0], -1)
                rank_temp = torch.norm(weight_vec, 2, 1)
            rank_dict[k] = rank_temp
            rank_list.append(rank_temp)
            total_pruned += rank_temp.shape[0]
            index += 1

最後,基於SFP的正則化方法,FPGM屬於最新演進的剪枝策略,能夠獲得更好的正則化效果:

https://blog.csdn.net/nature553863/article/details/9776004

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章