tensorflow加載多個固化pb模型

使用tensorflow加載多個pb模型時,會引起變量衝突,解決方法按照如下方法加載模型可解決:

class Model:
    def __init__(self, model_file):
        self.graph = tf.Graph()
        self.graph_def = tf.GraphDef()
        with gfile.FastGFile(model_file, 'rb') as f:
            self.graph_def.ParseFromString(f.read())
        with self.graph.as_default():
            tf.import_graph_def(self.graph_def, name='')

        self.sess = tf.Session(graph=self.graph, config=config)

    def predict(self, images: list):
        output_node = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[-1].name)
        input_x = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[0].name)

        w = input_x.shape[1]
        h = input_x.shape[2]
        data = []

        for img in images:
            img = img.resize((w, h))
            img = np.array(img).astype(float)
            data.append(img)

        feed = {input_x: data}
        out = self.sess.run(output_node, feed)
        return out

由於之前沒有定義輸入、輸出節點name,故採用節點索引的方式獲取輸入、輸出節點tensor

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章