Forcing a float32 TensorFlow model down to a float16 half-precision model


After training finishes in the TensorFlow framework, you usually want to compress the model before deploying it. One option is the 8-bit (uint8) quantization described in the earlier post 《【Ubuntu】Tensorflow對訓練後的模型做8位(uint8)量化轉換》. The other is half-precision quantization: this post shows how to rewrite the compute nodes and constants (const) in a TensorFlow pb file so that a float32 model shrinks to half its size as a float16 model.
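The size saving falls straight out of the element width: a float16 value occupies two bytes where a float32 value occupies four, so a graph whose size is dominated by weight constants shrinks by roughly half. A quick standalone numpy sketch (not part of the conversion code) illustrates this:

import numpy as np

# a toy weight tensor with one million parameters
w32 = np.zeros(1000000, dtype=np.float32)
w16 = w32.astype(np.float16)
print(w32.nbytes)  # 4000000 bytes
print(w16.nbytes)  # 2000000 bytes, half the size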

1 Loading the pb model

Wrap the pb-loading logic in a function:

def load_graph(model_path):
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            # binary pb: parse the serialized GraphDef directly
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            # text pbtxt: parse via protobuf's text_format
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        # print every op name so the input/output nodes can be identified
        ops = graph.get_operations()
        for op in ops:
            print(op.name)
        return sess
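This is TF1-style code (tf.GraphDef, tf.Session); under TensorFlow 2 the same calls live under tf.compat.v1. A minimal usage sketch, assuming a frozen graph named test.pb sits in the working directory:

sess = load_graph("test.pb")  # prints every op name in the graph
print(len(sess.graph.get_operations()))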

2 Rewriting BatchNorm

Because BatchNorm is sensitive to numerical precision and needs to stay in float32, it requires special handling.

# Replace FusedBatchNorm with FusedBatchNormV2 so that float32 is still used when computing the backward gradients
def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'): 
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")

3 Graph conversion

Rebuild the graph, copying the parameters over from the original pb graph and converting them to float16:


def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16', input_name=None, output_names=None):
    # pick the target dtype for the new graph
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT

    # load the model to convert
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    # create a fresh GraphDef for the target graph
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    # walk the compute nodes of the loaded model
    for node in source_graph_def.node:
        # replace FusedBatchNorm nodes with FusedBatchNormV2
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # copy the compute node
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)

        # copy the attrs; the ones that matter are "T" (dtype) and "value" (weights)
        attrs = list(node.attr.keys())
        # keep BatchNorm attrs unchanged
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # for nodes other than BatchNorm, handle each attr individually
        for attr in attrs:
            # keep the explicitly listed nodes in float32
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            # switch float32 to the chosen target dtype
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify node dtype
                node.attr[attr].type = dtype
                
            # "value" is the key attr: the weights of Const nodes live in it
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # transform graph
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # write graph_def to model
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")

4 Complete code

import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import numpy as np

# object detection api input and output nodes
# TransformGraph expects lists of node names, so wrap the single input name in a list
input_name = ["input_tf"]
output_names = ["output:0"]
keep_fp32_node_name = []

def load_graph(model_path):
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            # binary pb: parse the serialized GraphDef directly
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            # text pbtxt: parse via protobuf's text_format
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        # print every op name so the input/output nodes can be identified
        ops = graph.get_operations()
        for op in ops:
            print(op.name)
        return sess

# Replace FusedBatchNorm with FusedBatchNormV2 so that float32 is still used when computing the backward gradients
def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'): 
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")

def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16', input_name=None, output_names=None):
    # pick the target dtype for the new graph
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT

    # load the model to convert
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    # create a fresh GraphDef for the target graph
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    # walk the compute nodes of the loaded model
    for node in source_graph_def.node:
        # replace FusedBatchNorm nodes with FusedBatchNormV2
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # copy the compute node
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)

        # copy the attrs; the ones that matter are "T" (dtype) and "value" (weights)
        attrs = list(node.attr.keys())
        # keep BatchNorm attrs unchanged
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # for nodes other than BatchNorm, handle each attr individually
        for attr in attrs:
            # keep the explicitly listed nodes in float32
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            # switch float32 to the chosen target dtype
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify node dtype
                node.attr[attr].type = dtype
                
            # "value" is the key attr: the weights of Const nodes live in it
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # transform graph
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # write graph_def to model
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")

save_path = "test"
name = "output_fp16.pb"
model_path="test.pb"
as_text = False
target_type = 'fp16'
convert_graph_to_fp16(model_path, save_path, name, as_text=as_text, target_type=target_type, input_name=input_name, output_names=output_names)
# check that the converted model can still be loaded
sess = load_graph(save_path+"/"+name)
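Beyond loading, it is worth running a forward pass through the converted graph. A hedged sketch: the tensor names below match the input_name/output_names configured above, but the input shape is a made-up placeholder for your own model, and because the input placeholder was rewritten to DT_HALF the feed data must be float16:

graph = sess.graph
x = graph.get_tensor_by_name("input_tf:0")
y = graph.get_tensor_by_name("output:0")
dummy = np.random.rand(1, 224, 224, 3).astype(np.float16)  # shape is model-specific
print(sess.run(y, feed_dict={x: dummy}))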