推理時 CNN 的 BN 摺疊（基於 KWS 項目）

原創: [email protected]
時間: 2020/06/18

（原文此處有三張插圖，抓取時僅留下「在這裏插入圖片描述」佔位文字。）
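對魔改的 CNN 做推理時，將 BN 摺疊進前一層：在訓練得到的權重上乘以係數 $1/\sqrt{\sigma^2+\epsilon}$，並把偏置修正為 $b' = (b-\mu)/\sqrt{\sigma^2+\epsilon} + \beta$（此處 BN 的縮放係數 $\gamma$ 取 1），從而在推理時省去 BN 層。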

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import os
import numpy as np
import tensorflow as tf

path = os.path.dirname(__file__)
sys.path.append(os.path.join(path, '../'))
import models
import input_data

# Parsed command-line arguments; assigned in the __main__ block from
# parser.parse_known_args() before tf.app.run invokes main().
FLAGS = None

def _fold_bn_params(weights, biases, mean, variance, beta, epsilon, gamma=None):
  """Return (weights, biases) with one batch-norm layer folded into them.

  Per output channel c (the last axis of ``weights``):
      scale[c]   = gamma[c] / sqrt(variance[c] + epsilon)
      w'[..., c] = w[..., c] * scale[c]
      b'[c]      = (b[c] - mean[c]) * scale[c] + beta[c]
  ``gamma`` defaults to all ones (a batch-norm built without a learned scale).
  The scale is computed in float64 and results are cast back to the input
  dtypes, like the original per-element loops.

  Args:
    weights: Weight array whose last axis indexes output channels
      (e.g. a 4-D conv kernel or a 2-D fully-connected matrix).
    biases: 1-D bias array, one entry per output channel.
    mean: 1-D batch-norm moving mean.
    variance: 1-D batch-norm moving variance.
    beta: 1-D batch-norm offset.
    epsilon: Constant added to the variance before the square root.
    gamma: Optional 1-D batch-norm scale; None means all ones.

  Returns:
    Tuple (folded_weights, folded_biases) with the dtypes of the inputs.

  Raises:
    ValueError: If a per-channel array does not match the weights' last axis.
  """
  weights = np.asarray(weights)
  biases = np.asarray(biases)
  channels = weights.shape[-1]
  per_channel = (('biases', biases), ('moving_mean', mean),
                 ('moving_variance', variance), ('beta', beta))
  if gamma is not None:
    per_channel += (('gamma', gamma),)
  for label, values in per_channel:
    if np.shape(values) != (channels,):
      raise ValueError('%s has shape %s but weights have %d output channels'
                       % (label, np.shape(values), channels))
  # Per-channel multiplier, evaluated in float64 for accuracy.
  scale = 1.0 / np.sqrt(np.asarray(variance, dtype=np.float64) + epsilon)
  if gamma is not None:
    scale = scale * np.asarray(gamma, dtype=np.float64)
  # Broadcasting applies `scale` along the last (output-channel) axis, so the
  # same expression handles conv kernels and fully-connected matrices.
  folded_weights = (weights * scale).astype(weights.dtype)
  folded_biases = ((biases - mean) * scale + beta).astype(biases.dtype)
  return folded_weights, folded_biases


def fold_batch_norm(wanted_words, sample_rate, clip_duration_ms,
                           window_size_ms, window_stride_ms,
                           dct_coefficient_count, model_architecture, model_size_info,
                           checkpoint=None, epsilon=1e-3):
  """Fold batch-norm layers into the preceding weights/biases and save a new ckpt.

  Builds the model graph (is_training=False) and restores ``checkpoint``.
  For every batch-norm layer, identified by its ``moving_mean`` variable, it
  rewrites the matching weight/bias variables as
      w' = w * gamma / sqrt(var + eps)
      b' = (b - mean) * gamma / sqrt(var + eps) + beta
  gamma is taken as 1 when the layer has no ``gamma`` variable. The result
  is saved to ``<checkpoint>_bnfused``.

  Args:
    wanted_words: Comma-separated list of the words we're trying to recognize.
    sample_rate: How many samples per second are in the input audio files.
    clip_duration_ms: Length of each audio clip, in milliseconds.
    window_size_ms: Time slice duration to estimate frequencies from.
    window_stride_ms: How far apart time slices should be.
    dct_coefficient_count: Number of frequency bands to analyze.
    model_architecture: Name of the model to build (passed to
      models.create_model).
    model_size_info: Model dimensions passed to models.create_model.
    checkpoint: Checkpoint path to restore; defaults to FLAGS.checkpoint.
    epsilon: Value added to the variance. It should equal the epsilon the
      model's batch-norm layers were built with. 1e-3 is the value the
      original script hard-coded as the tf.slim default.

  Raises:
    KeyError: If an expected weight/bias/batch-norm variable is missing.
    ValueError: If a layer's weight channel count does not match its
      batch-norm statistics.
  """
  if checkpoint is None:
    checkpoint = FLAGS.checkpoint

  tf.logging.set_verbosity(tf.logging.INFO)
  sess = tf.InteractiveSession()
  try:
    words_list = input_data.prepare_words_list(wanted_words.split(','))
    model_settings = models.prepare_model_settings(
        len(words_list), sample_rate, clip_duration_ms, window_size_ms,
        window_stride_ms, dct_coefficient_count)

    fingerprint_input = tf.placeholder(
        tf.float32, [None, model_settings['fingerprint_size']],
        name='fingerprint_input')
    # Only the model's variables are needed here; the logits are unused.
    models.create_model(
        fingerprint_input,
        model_settings,
        model_architecture,
        model_size_info,
        is_training=False)

    models.load_variables_from_checkpoint(sess, checkpoint)
    saver = tf.train.Saver(tf.global_variables())

    tf.logging.info('Folding batch normalization layer parameters to preceding layer weights/biases')
    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    # Index once instead of re-scanning the collection for every lookup.
    variables_by_name = {v.name: v for v in global_vars}

    def lookup(name):
      """Return the graph variable called ``name`` or raise a clear KeyError."""
      try:
        return variables_by_name[name]
      except KeyError:
        raise KeyError('Variable %r not found in graph' % name)

    mean_variables = [v for v in global_vars if 'moving_mean' in v.name]
    if not mean_variables:
      tf.logging.warning('No moving_mean variables found; nothing to fold')

    for i, mean_var in enumerate(mean_variables):
      mean_name = mean_var.name
      # Layer i's parameters are looked up by TF's default names for unnamed
      # tf.Variable objects: weights at Variable, Variable_2, Variable_4, ...
      # and biases at Variable_1, Variable_3, ... This assumes the model
      # creates one (weight, bias) pair per batch-norm layer, in the same
      # order as the moving_mean variables (TODO confirm in models.py).
      wt_name = 'Variable:0' if i == 0 else 'Variable_%d:0' % (2 * i)
      bias_name = 'Variable_%d:0' % (2 * i + 1)
      wt_var = lookup(wt_name)
      bias_var = lookup(bias_name)
      variance_var = lookup(mean_name.replace('moving_mean', 'moving_variance'))
      beta_var = lookup(mean_name.replace('moving_mean', 'beta'))
      # A learned scale is optional; when absent, gamma is effectively 1.0.
      gamma_var = variables_by_name.get(mean_name.replace('moving_mean', 'gamma'))

      (mean_values, variance_values, beta_values,
       bias_values, wt_values) = sess.run(
           [mean_var, variance_var, beta_var, bias_var, wt_var])
      gamma_values = sess.run(gamma_var) if gamma_var is not None else None

      folded_weights, folded_biases = _fold_bn_params(
          wt_values, bias_values, mean_values, variance_values, beta_values,
          epsilon, gamma=gamma_values)

      tf.logging.info('Updating '+wt_name)
      sess.run(tf.assign(wt_var, folded_weights))
      tf.logging.info('Updating '+bias_name)
      sess.run(tf.assign(bias_var, folded_biases))

    # Write fused weights to a new checkpoint next to the original.
    tf.logging.info('Saving new checkpoint at '+checkpoint+'_bnfused')
    saver.save(sess, checkpoint+'_bnfused')
  finally:
    sess.close()



def main(_):
  """tf.app.run entry point: fold batch-norm for the model described by FLAGS."""
  # Rebuild the model from the parsed flags, restore it, and fold BN into it.
  fold_batch_norm(
      wanted_words=FLAGS.wanted_words,
      sample_rate=FLAGS.sample_rate,
      clip_duration_ms=FLAGS.clip_duration_ms,
      window_size_ms=FLAGS.window_size_ms,
      window_stride_ms=FLAGS.window_stride_ms,
      dct_coefficient_count=FLAGS.dct_coefficient_count,
      model_architecture=FLAGS.model_architecture,
      model_size_info=FLAGS.model_size_info)
  

if __name__ == '__main__':
  # Command-line interface. Several dataset flags (data_url, data_dir,
  # *_percentage, batch_size) are accepted but not read by fold_batch_norm.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_url',
      type=str,
      # pylint: disable=line-too-long
      default='',
      # default='http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
      # pylint: enable=line-too-long
      help='Location of speech training data archive on the web.')
  parser.add_argument(
      '--data_dir',
      type=str,
      # default='/tmp/speech_dataset/',
      default='../../data',
      help="""
      Where to download the speech training data to.
      """)
  parser.add_argument(
      '--silence_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be silence.
      """)
  parser.add_argument(
      '--unknown_percentage',
      type=float,
      default=10.0,
      help="""\
      How much of the training data should be unknown words.
      """)
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a test set.')
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of wavs to use as a validation set.')
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=40.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=40.0,
      # Previously duplicated the window_size_ms help text.
      help='How far to move in time between spectrogram timeslices',)
  parser.add_argument(
      '--dct_coefficient_count',
      type=int,
      default=10,
      help='How many bins to use for the MFCC fingerprint',)
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='How many items to train with at once',)
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go,nihaoxr,xrxr',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--checkpoint',
      type=str,
      default='../train_model/526_cnn/best/cnn_8884.ckpt-13200',
      help='Checkpoint to load the weights from.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='cnn2',
      help='What model architecture to use')
  parser.add_argument(
      '--model_size_info',
      type=int,
      nargs="+",
      default=[128,128,128],
      help='Model dimensions - different for various models')

  # Unrecognized arguments are forwarded to tf.app.run.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

執行命令：

python3 ./utils/fold_batchnorm_cnn.py \
--data_dir ../data \
--dct_coefficient_count 10 \
--window_size_ms 32 \
--window_stride_ms 20 \
--checkpoint ./train_model/615_cnn_with_32_frame/best/cnn2_9127.ckpt-14000 \
--model_architecture cnn2 \
--model_size_info 28 10 4 1 1 30 10 4 2 1 16 128
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章