PyTorch: merge BN into conv (model from mmdetection)

import os
import torch

SRC_WEIGHT = 'epoch_12.pth'
DST_WEIGHT = 'merge_bn.pth'
# nn.BatchNorm* default epsilon; must match the eps the model was trained with.
BN_EPS = 1e-5


def fuse_conv_bn(state_dict, conv_prefix, bn_prefix, eps=BN_EPS):
    """Fold one eval-mode BatchNorm into the convolution that feeds it, in place.

    For a conv output ``y = W * x + b`` followed by BN::

        bn(y) = (y - mean) / sqrt(var + eps) * gamma + beta

    which is the same as a single conv with ``W' = W * s`` and
    ``b' = (b - mean) * s + beta``, where ``s = gamma / sqrt(var + eps)`` is
    broadcast over the output-channel axis (dim 0 of ``W``).

    Args:
        state_dict: mapping of parameter name -> tensor; modified in place.
        conv_prefix: key prefix of the conv, e.g. ``'backbone.conv1'``.
        bn_prefix: key prefix of the BN, e.g. ``'backbone.bn1'``.
        eps: BN epsilon.

    Raises:
        KeyError: if the conv weight or a required BN entry is missing.
    """
    conv_weight_name = conv_prefix + '.weight'
    conv_bias_name = conv_prefix + '.bias'

    conv_weight = state_dict[conv_weight_name]
    gamma = state_dict[bn_prefix + '.weight']
    beta = state_dict[bn_prefix + '.bias']
    mean = state_dict[bn_prefix + '.running_mean']
    var = state_dict[bn_prefix + '.running_var']

    scale = gamma / torch.sqrt(var + eps)

    # A conv that already has a bias must keep it in the fused result.
    conv_bias = state_dict.get(conv_bias_name)
    if conv_bias is None:
        conv_bias = mean.new_zeros(mean.shape)

    # Broadcast the per-output-channel scale over the remaining weight dims,
    # so 1-D/2-D/3-D conv weights are all handled.
    scale_shape = (-1,) + (1,) * (conv_weight.dim() - 1)
    state_dict[conv_weight_name] = conv_weight * scale.reshape(scale_shape)
    state_dict[conv_bias_name] = (conv_bias - mean) * scale + beta

    # num_batches_tracked may be missing (older PyTorch did not save it).
    for suffix in ('weight', 'bias', 'running_mean', 'running_var',
                   'num_batches_tracked'):
        state_dict.pop(bn_prefix + '.' + suffix, None)


def merge_bn(state_dict, eps=BN_EPS):
    """Fold every foldable BatchNorm of an mmdet ResNet-style state dict.

    Naming assumptions (mmdet ResNet/ResNeXt -- TODO confirm for other
    backbones):
      * ``<path>.bnK`` follows ``<path>.convK``;
      * ``<path>.downsample.1`` follows ``<path>.downsample.0``;
      * ``<path>.bn2`` is kept as-is when ``<path>.conv2_offset`` exists,
        i.e. when conv2 is a deformable conv.

    BN-named entries without running statistics cannot be folded and are
    skipped with a message.

    Returns:
        The same ``state_dict`` object, mutated in place.
    """
    dcn_bn_prefixes = {
        name[:-len('conv2_offset.weight')] + 'bn2'
        for name in state_dict
        if name.endswith('conv2_offset.weight')
    }

    # Iterate over a snapshot: BN entries are popped while we loop.
    for name in list(state_dict):
        if not name.endswith('.weight'):
            continue
        bn_prefix = name[:-len('.weight')]
        parent, dot, leaf = bn_prefix.rpartition('.')

        if bn_prefix.endswith('downsample.1'):
            conv_prefix = bn_prefix[:-len('downsample.1')] + 'downsample.0'
        elif 'bn' in leaf:
            if bn_prefix in dcn_bn_prefixes:
                continue
            # Only rename the last path component, never parent modules.
            conv_prefix = parent + dot + leaf.replace('bn', 'conv')
        else:
            continue

        if bn_prefix + '.running_mean' not in state_dict:
            print('skip {}: no running statistics, not a foldable BatchNorm'
                  .format(bn_prefix))
            continue
        fuse_conv_bn(state_dict, conv_prefix, bn_prefix, eps)
    return state_dict


def main(src_weight=SRC_WEIGHT, dst_weight=DST_WEIGHT, eps=BN_EPS):
    """Load an mmdet checkpoint, fold its BN layers into convs and save it."""
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model = torch.load(src_weight, map_location='cpu')
    weight = merge_bn(model['state_dict'], eps)

    for name in weight:
        print(name)

    new_model = {}
    if 'meta' in model:  # tolerate checkpoints saved without a 'meta' entry
        new_model['meta'] = model['meta']
    new_model['state_dict'] = weight
    torch.save(new_model, dst_weight)


if __name__ == '__main__':
    main()

mmdet/models/backbones/resnet.py 和 mmdet/models/backbones/resnext.py 也需要做相應修改:被融合的卷積層需改為 bias=True,並移除對應的 BN 層,否則無法載入 merge_bn.pth。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章