Network model pruning in TensorFlow: the model_pruning module

Part of the code is adapted from: https://blog.csdn.net/lai_cheng/article/details/90643100

Pruning:

Pruning is an iterative process that applies some criterion to set a group of weights, or a single weight, to 0, zeroing out the corresponding neurons so that the network connections become sparse, inference speeds up, and the model shrinks. Typical criteria include brute-force enumeration of weight combinations, estimating each weight's importance with a diagonal Hessian approximation, ranking weights by a first-order Taylor expansion of the model's cost function, ranking weights by their L1/absolute magnitude, and scoring weights by their impact on a small validation set. The pruned unit can be an entire convolution kernel, a whole fully connected layer, or an individual weight inside a convolution kernel or fully connected layer. The goal is to zero out redundant parameters in order to reduce model size (which needs a special storage format), reduce the number of computed parameters (which needs special hardware support), and sparsify the connections to speed up inference.
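
As a concrete illustration of the absolute-value (L1 magnitude) criterion mentioned above, here is a minimal NumPy sketch; the percentile-based threshold and the function name are purely illustrative and not part of the model_pruning module:

import numpy as np

def magnitude_prune(weights, sparsity=0.5):
    """Zero out the smallest-magnitude weights until roughly `sparsity` of them are 0."""
    threshold = np.percentile(np.abs(weights), sparsity * 100)   # magnitude cutoff
    mask = (np.abs(weights) > threshold).astype(weights.dtype)   # 1 = keep, 0 = prune
    return weights * mask, mask

w = np.random.randn(4, 4).astype(np.float32)
pruned_w, mask = magnitude_prune(w, sparsity=0.5)
print("achieved sparsity:", 1.0 - mask.mean())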

The model pruning approach:

model_pruning prunes during training. You only need to pick the layers to prune; each selected layer is given a binary mask variable with exactly the same shape as that layer's weight tensor. The mask determines which weights take part in the forward computation. The mask-update algorithm injects special operators into the TensorFlow training graph that sort the current layer's weights by absolute value and set the mask to 0 for every weight whose magnitude falls below a threshold. Gradients also pass through the mask, so weights that are masked out (mask = 0) receive no update during back-propagation. When the model is saved, the pruning ops can be stripped so that the weights are stored directly in sparse form, which realizes the sparse connectivity.
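
Conceptually, a pruned layer computes with weight * mask instead of the raw weight, so the mask gates both the forward pass and the gradients. A rough TensorFlow 1.x sketch of this idea (illustrative only, not the module's internal implementation):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])
w = tf.get_variable("w", shape=[784, 10])
mask = tf.get_variable("mask", shape=[784, 10], trainable=False,
                       initializer=tf.ones_initializer())   # 1 = keep, 0 = pruned

# The forward pass uses the masked weights; because of the product rule,
# d(loss)/dw is also multiplied by the mask, so pruned weights get no update.
logits = tf.matmul(x, w * mask)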

The official model_pruning example:

tf.app.flags.DEFINE_string(
    'pruning_hparams', '',
    """Comma separated list of pruning-related hyperparameters""")

with tf.Graph().as_default():

  # Create global step variable
  global_step = tf.train.get_or_create_global_step()

  # Parse pruning hyperparameters
  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

  # Create a pruning object using the pruning specification
  p = pruning.Pruning(pruning_hparams, global_step=global_step)

  # Add conditional mask update op. Executing this op will update all
  # the masks in the graph if the current global step is in the range
  # [begin_pruning_step, end_pruning_step] as specified by the pruning spec
  mask_update_op = p.conditional_mask_update_op()

  # Add summaries to keep track of the sparsity in different layers during training
  p.add_pruning_summaries()

  with tf.train.MonitoredTrainingSession(...) as mon_sess:
    # Run the usual training op in the tf session
    mon_sess.run(train_op)

    # Update the masks by running the mask_update_op
    mon_sess.run(mask_update_op)

Make sure the global_step passed to Pruning keeps increasing as training iterates; otherwise no pruning effect will be produced!
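
A common way to guarantee this is to hand the same global_step to the optimizer, which then increments it on every training step. A minimal sketch, assuming loss and pruning_hparams are defined as in the example above:

global_step = tf.train.get_or_create_global_step()

# minimize() increments global_step once per training step, so the Pruning
# object always sees a growing step counter.
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)
p = pruning.Pruning(pruning_hparams, global_step=global_step)
mask_update_op = p.conditional_mask_update_op()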

Fully connected layer pruning:

from tensorflow.contrib.model_pruning.python.layers import layers
fc_layer1 = layers.masked_fully_connected(ft, 200)
fc_layer2 = layers.masked_fully_connected(fc_layer1, 100)
prediction = layers.masked_fully_connected(fc_layer2, 10)
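
Each masked layer creates an extra mask and threshold variable next to its weights. As a quick sanity check, the module's collection helpers can list them (a small sketch, assuming the layers above have already been built in the current graph):

from tensorflow.contrib.model_pruning.python import pruning

print(pruning.get_weights())     # weight tensors of the masked layers
print(pruning.get_masks())       # one binary mask per masked layer
print(pruning.get_thresholds())  # the per-layer pruning thresholds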

Convolutional layer pruning:

from tensorflow.contrib.model_pruning.python.layers import layers

layers.masked_conv2d(indata, num_outputs=outchannel, kernel_size=[5, 5], padding='SAME', activation_fn=tf.nn.relu)

Workflow: first pick the layers to prune and replace them with the masked versions above, then configure the pruning hyperparameters, and finally, at every training step, run the pruning op before running the training op.
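
The pruning hyperparameters can be set either as attributes on the HParams object (as the complete example below does) or through the comma-separated string accepted by get_pruning_hparams().parse(), for example:

# Equivalent configuration through the comma-separated hparams string:
pruning_hparams = pruning.get_pruning_hparams().parse(
    "begin_pruning_step=0,end_pruning_step=250,"
    "pruning_frequency=1,target_sparsity=0.9")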

 

Complete model pruning code

Source: https://blog.csdn.net/lai_cheng/article/details/90643100

1. The three masked_fully_connected calls perform the fully connected layer pruning.

2. The pruning hyperparameters (begin/end step, pruning frequency, target sparsity) are configured in model_bulid.

3. train_network runs the pruning op before each training step; everything else is ordinary CNN code.

Pruning a LeNet network with TensorFlow

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
import time

class LeNet_Mode():
    """ create LeNet network use tensorflow
        LeNet network structure:
        (conv 5x5 32 ,pool/2)
        (conv 5x5 64, pool/2)
        (fc 100)=>=>(fc classes)
    """
    def conv_layer(self, data, ksize, stride, name, w_biases = False,padding = "SAME"):
        with tf.variable_scope(name,reuse=tf.AUTO_REUSE):
            w_init = tf.contrib.layers.xavier_initializer()
            w = tf.get_variable(name= name,shape= ksize, initializer= w_init)
            biases = tf.Variable(tf.constant(0.0, shape=[ksize[3]], dtype=tf.float32), name='biases')
        if w_biases == False:
            cov = tf.nn.conv2d(input= data, filter= w, strides= stride, padding= padding)
        else:
            cov = tf.nn.conv2d(input= data, filter= w, strides= stride, padding= padding) + biases
        return cov
 
    def pool_layer(self, data, ksize, stride, name, padding= 'VALID'):
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            max_pool =  tf.nn.max_pool(value= data, ksize= ksize, strides= stride,padding= padding)
        return max_pool
 
    def flatten(self,data):
        [a,b,c,d] = data.get_shape().as_list()
        ft = tf.reshape(data,[-1,b*c*d])
        return ft
 
    def fc_layer(self,data,name,fc_dims):
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            data_shape = data.get_shape().as_list()
            w_init = tf.contrib.layers.xavier_initializer()
            w = tf.get_variable(shape=[data_shape[1],fc_dims],name= 'w',initializer=w_init)
            # w = tf.Variable(tf.truncated_normal([data_shape[1], fc_dims], stddev=0.01),'w')
            biases = tf.Variable(tf.constant(0.0, shape=[fc_dims], dtype=tf.float32), name='biases')
            fc = tf.nn.relu(tf.matmul(data,w)+ biases)
        return fc
 
    def finlaout_layer(self,data,name,fc_dims):
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            w_init = tf.contrib.layers.xavier_initializer()
            w = tf.get_variable(shape=[data.shape[1],fc_dims],name= 'w',initializer=w_init)
            biases = tf.Variable(tf.constant(0.0, shape=[fc_dims], dtype=tf.float32), name='biases')
            # fc = tf.nn.softmax(tf.matmul(data,w)+ biases)
            fc = tf.matmul(data,w)+biases
        return fc
 
    def model_bulid(self, height, width, channel,classes):
        x = tf.placeholder(dtype= tf.float32, shape = [None,height,width,channel])
        y = tf.placeholder(dtype= tf.float32 ,shape=[None,classes])
 
        # conv 1: for an Nx28x28x1 input, (conv 5x5 32, pool/2)
        conv1_1 = tf.nn.relu(self.conv_layer(x,ksize=[5,5,channel,32],stride=[1,1,1,1],padding="SAME",name="conv1_1")) # Nx28x28x1 ==> Nx28x28x32
        pool1_1 = self.pool_layer(conv1_1,ksize=[1,2,2,1],stride=[1,2,2,1],name="pool1_1") # Nx14x14x32
 
        # conv 2: (conv 5x5 64, pool/2)
        conv2_1 = tf.nn.relu(self.conv_layer(pool1_1,ksize=[5,5,32,64],stride=[1,1,1,1],padding="SAME",name="conv2_1")) # Nx14x14x64
        pool2_1 = self.pool_layer(conv2_1,ksize=[1,2,2,1],stride=[1,2,2,1],name="pool2_1") # Nx7x7x64
 
        # Flatten
        ft = self.flatten(pool2_1)
 
        # Dense layers with pruning: (fc 200) => (fc 100) => (fc classes)
        fc_layer1 = layers.masked_fully_connected(ft, 200)
        fc_layer2 = layers.masked_fully_connected(fc_layer1, 100)
        prediction = layers.masked_fully_connected(fc_layer2, 10)
 
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction, labels=y))
        #  original Dense layer
        # fc1 = self.fc_layer(ft,fc_dims=100,name="fc1")
        # finaloutput = self.finlaout_layer(fc1,fc_dims=10,name="final")
 
        #  pruning op
        global_step = tf.train.get_or_create_global_step()
        reset_global_step_op = tf.assign(global_step, 0)
        # Get, Print, and Edit Pruning Hyperparameters
        pruning_hparams = pruning.get_pruning_hparams()
        print("Pruning Hyper parameters:", pruning_hparams)
        # Change hyperparameters to meet our needs
        pruning_hparams.begin_pruning_step = 0
        pruning_hparams.end_pruning_step = 250
        pruning_hparams.pruning_frequency = 1
        pruning_hparams.sparsity_function_end_step = 250
        pruning_hparams.target_sparsity = .9
        # Create a pruning object using the pruning specification, sparsity seems to have priority over the hparam
        p = pruning.Pruning(pruning_hparams, global_step=global_step)
        prune_op = p.conditional_mask_update_op()
 
        # optimize
        LEARNING_RATE_BASE = 0.001
        LEARNING_RATE_DECAY = 0.9
        LEARNING_RATE_STEP = 300
        # drive the decay with the same global_step that the optimizer increments
        learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                                   global_step,
                                                   LEARNING_RATE_STEP,
                                                   LEARNING_RATE_DECAY,
                                                   staircase=True)
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            optimize = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss,global_step)
 
        # prediction
        prediction_label = prediction
        correct_prediction = tf.equal(tf.argmax(prediction_label,1),tf.argmax(y,1))
        accurary = tf.reduce_mean(tf.cast(correct_prediction,dtype=tf.float32))
        correct_times_in_batch = tf.reduce_sum(tf.cast(correct_prediction,dtype=tf.int32))  # number of correct predictions in the batch
 
        return dict(
            x=x,
            y=y,
            optimize=optimize,
            correct_prediction=prediction_label,
            correct_times_in_batch=correct_times_in_batch,
            cost=loss,
            accurary = accurary,
            prune_op = prune_op
        )
 
    def init_sess(self):
        init = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
        self.sess = tf.Session()
        self.sess.run(init)
 
    def train_network(self,graph,x_train,y_train):
        # TensorFlow keeps adding nodes to the same graph, so the memory footprint grows over time
        # reset graph
        # tf.reset_default_graph()
        # prune op
        self.sess.run(graph['prune_op'])
        self.sess.run(graph['optimize'], feed_dict={graph['x']:x_train, graph['y']:y_train})
        # print("cost: ",self.sess.run(graph['cost'],feed_dict={graph['x']:x_train, graph['y']:y_train}))
        # print("accurary: ",self.sess.run(graph['accurary'],feed_dict={graph['x']:x_train, graph['y']:y_train}))
 
    def save_model(self):
        saver = tf.train.Saver()
        save_path = saver.save(self.sess,"save/model.ckpt")
 
    def load_data(self):
        mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
        g = self.model_bulid(28, 28, 1, 10)
        # Build the model first, then initialize it, just once
        start = time.time()
        self.init_sess()
        for epoch in range(30):
            for i in range(1500):
                batch_xs, batch_ys = mnist.train.next_batch(1000)
                batch_xs = np.reshape(batch_xs,[-1,28,28,1])
                # sess.run(g['prune_op'], feed_dict={g['x']:batch_xs, g['y']:batch_ys})
                self.train_network(g,batch_xs,batch_ys)
                print("Train cost accurary print:","cost: ", self.sess.run(g['cost'], feed_dict={g['x']: batch_xs, g['y']: batch_ys}), "accurary: ",
                      self.sess.run(g['accurary'], feed_dict={g['x']: batch_xs, g['y']: batch_ys}))
                if i % 30==0:
                    batch_xs_test, batch_ys_test = mnist.test.next_batch(1000)
                    batch_xs_test = np.reshape(batch_xs_test,[-1,28,28,1])
                    acc = self.sess.run(g['accurary'],feed_dict={g['x']:batch_xs_test, g['y']:batch_ys_test})
                    print("******Test cost accurary print******:","cost: ",self.sess.run(g['cost'],feed_dict={g['x']:batch_xs_test, g['y']:batch_ys_test}),"accurary: ",
                          self.sess.run(g['accurary'],feed_dict={g['x']:batch_xs_test, g['y']:batch_ys_test}))
                    print("Sparsity of layers (should be 0)", self.sess.run(tf.contrib.model_pruning.get_weight_sparsity()))
                    if acc > 0.9:
                        self.save_model()
 
        end = time.time()
        print(end - start, "seconds elapsed")
 
if __name__ == '__main__':
    LeNet = LeNet_Mode()
    LeNet.load_data()

Inspecting the pruning result

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
 
 
model_dir = "save/"
 
ckpt = tf.train.get_checkpoint_state(model_dir)
ckpt_path = ckpt.model_checkpoint_path
 
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
param_dict = reader.get_variable_to_shape_map()
 
for key, val in param_dict.items():
    try:
        print(key, val)
        print_tensors_in_checkpoint_file(ckpt_path, tensor_name=key, all_tensors=False,
                                         all_tensor_names=False)
    except:
        pass
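
To see how sparse the pruning actually made each layer, the checkpoint reader can also return tensor values: the fraction of zeros in each mask variable is the achieved sparsity (a small check; the name filter below is an assumption about how the masked layers name their masks):

import numpy as np

for key in param_dict:
    if key.endswith("mask"):                      # assumed naming of the mask variables
        mask = reader.get_tensor(key)
        print(key, "sparsity:", np.mean(mask == 0.0))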
