TensorFlow sessions and graphs

1. Using `set_session`, `clear_session` and `get_session`


class Recog_Fish(object):
    """Keras text-detection model built inside its own TensorFlow graph/session.

    Args:
        kerasTextModel: path handed to ``load_weights`` for the detection model.
        IMGSIZE: stored as ``self.IMGSIZE``; not used by the code shown here.
        keras_anchors: comma-separated anchor values, reshaped into pairs.
        class_names: sequence of class labels; only its length is used here.
    """

    def __init__(self, kerasTextModel, IMGSIZE, keras_anchors, class_names):
        self.kerasTextModel = kerasTextModel
        self.IMGSIZE = IMGSIZE
        self.keras_anchors = keras_anchors
        self.class_names = class_names
        self.box_score = None  # set by creat_graph()
        # Private graph so this detector's ops do not land in the default graph
        # shared with other models in the same process.
        self.text_detect_graph = tf.Graph()
        # Alias kept for compatibility; creat_graph() loads weights from it.
        self.load_model = self.kerasTextModel

    def creat_graph(self, gpu_memory_fraction=0.1):
        """Build the detector in ``self.text_detect_graph`` and load its weights.

        Args:
            gpu_memory_fraction: value for
                ``gpu_options.per_process_gpu_memory_fraction`` of the session
                (defaults to the previously hard-coded 0.1).

        Sets ``self.sess``, ``self.textModel``, ``self.image_shape``,
        ``self.input_shape`` and ``self.box_score``.
        """
        anchors = np.array(
            [float(x) for x in self.keras_anchors.split(',')]
        ).reshape(-1, 2)
        num_classes = len(self.class_names)

        config = tf.ConfigProto()
        # Alternative to a fixed fraction: config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction

        # Reset Keras' global state BEFORE installing our own session.
        # clear_session() drops the session Keras currently holds, so calling it
        # after set_session() (as this code used to) discarded the configured
        # session and its GPU memory limit.
        keras.backend.clear_session()

        with self.text_detect_graph.as_default():
            # Bind the session to this detector's graph explicitly and make it
            # the session Keras uses for load_weights() and later inference.
            self.sess = tf.Session(graph=self.text_detect_graph, config=config)
            set_session(self.sess)

            textModel = yolo_text(num_classes, anchors)
            textModel.load_weights(self.load_model)
            self.image_shape = K.placeholder(shape=(2,))  # original image size: h, w
            self.input_shape = K.placeholder(shape=(2,))  # resized image size: h, w
            self.box_score = box_layer(
                [*textModel.output, self.image_shape, self.input_shape],
                anchors,
                num_classes,
            )
            self.textModel = textModel

2. Using `with self.graph.as_default():`


class FLATE():
    """Load the set of models used by one recognition pipeline.

    The sequence-recognition, type and new-energy models are built under
    ``self.graph`` (the default graph captured at construction time). The text
    detector is built separately by ``keras_detect.Text_Detect.creat_graph``.
    """

    def __init__(self, model_seq_rec, kerastextmodel, typeDistinguish_model, newenergy_model):
        """Build and load all models.

        Args:
            model_seq_rec: passed to ``self.model_seq_rec`` to build the
                sequence-recognition model.
            kerastextmodel: weights path for the text-detection model.
            typeDistinguish_model: weights path for the type model; NOTE this
                file is also written to (see below).
            newenergy_model: weights path for the new-energy model.

        Raises:
            Exception: if loading the new-energy weights fails (original error
                is chained as ``__cause__``).
        """
        # Capture the current default graph so later loads can re-enter it
        # explicitly with as_default().
        self.graph = tf.get_default_graph()
        with self.graph.as_default():
            self.modelSeqRec = self.model_seq_rec(model_seq_rec)

        self.detect_model = keras_detect.Text_Detect(
            kerasTextModel=kerastextmodel,
            IMGSIZE=config.IMGSIZE,
            keras_anchors=config.keras_anchors,
            class_names=config.class_names,
        )
        self.detect_model.creat_graph()

        with self.graph.as_default():
            self.type_model = self.Getmodel_tensorflow(5)
            self.type_model.load_weights(typeDistinguish_model)
            # NOTE(review): this overwrites the file just loaded from with a
            # full model save — confirm this side effect is intended.
            self.type_model.save(typeDistinguish_model)

            self.NewenergyModel = self.get_NewEnergyModel()
            try:
                self.NewenergyModel.load_weights(newenergy_model)
            except Exception as err:
                # Was a bare ``except:`` — that also trapped KeyboardInterrupt /
                # SystemExit and hid the real failure; chain it instead.
                raise Exception("No newenergy weight file!") from err

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章