使用tensorflow加載多個pb模型時,會引起變量衝突,解決方法按照如下方法加載模型可解決:
class Model:
def __init__(self, model_file):
self.graph = tf.Graph()
self.graph_def = tf.GraphDef()
with gfile.FastGFile(model_file, 'rb') as f:
self.graph_def.ParseFromString(f.read())
with self.graph.as_default():
tf.import_graph_def(self.graph_def, name='')
self.sess = tf.Session(graph=self.graph, config=config)
def predict(self, images: list):
output_node = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[-1].name)
input_x = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[0].name)
w = input_x.shape[1]
h = input_x.shape[2]
data = []
for img in images:
img = img.resize((w, h))
img = np.array(img).astype(float)
data.append(img)
feed = {input_x: data}
out = self.sess.run(output_node, feed)
return out
由於之前沒有定義輸入、輸出節點name,故採用節點索引的方式獲取輸入、輸出節點tensor