TensorFlow/Keras 實現的 U-Net 語義分割（二分類）網絡

完整代碼如下：

class UNet():
  """U-Net semantic-segmentation network for binary (one-vs-rest) masks.

  One network is trained per target class: pixels whose label value equals
  ``train_class`` are remapped to 1 (foreground), everything else to 0.
  Built on tf.keras; images are .jpg files, label masks are .png files.
  """

  def __init__(self,
               input_width,
               input_height,
               num_classes,
               train_class,
               train_images,
               train_instances,
               val_images,
               val_instances,
               epochs,
               lr,
               lr_decay,
               batch_size,
               model_path,
               save_path,
               train_mode
               ):
    """Store the training / inference configuration.

    Args:
      input_width, input_height: spatial size the network input expects.
      num_classes: number of output channels (1 for binary segmentation).
      train_class: label value treated as foreground; a falsy value (0/None)
        disables remapping.  NOTE: class id 0 therefore cannot be selected.
      train_images, train_instances: directories of training images / masks.
      val_images, val_instances: directories of validation images / masks.
      epochs, lr, lr_decay, batch_size: optimisation hyper-parameters.
      model_path: existing .h5 model to load (resume training / prediction).
      save_path: path the trained model is written to.
      train_mode: truthy -> resume from ``model_path``; falsy -> build anew.
    """
    self.input_width = input_width
    self.input_height = input_height
    self.num_classes = num_classes
    self.train_class = train_class
    self.train_images = train_images
    self.train_instances = train_instances
    self.val_images = val_images
    self.val_instances = val_instances
    self.epochs = epochs
    self.lr = lr
    self.lr_decay = lr_decay
    self.batch_size = batch_size
    self.model_path = model_path
    self.save_path = save_path
    self.train_mode = train_mode

  # ---------------------------------------------------------- U-Net topology
  def leftNetwork(self, inputs):
    """Contracting path (encoder): four 2x2 max-pool stages.

    Returns the five feature maps [o_1 .. o_5] (32/64/128/256/512 channels)
    that feed the skip connections of the expanding path.
    """
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
      x = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
      x = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
      skips.append(x)                       # pre-pool map -> skip connection
      x = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    # Bottleneck (no pooling afterwards).
    x = layers.Conv2D(512, (3, 3), padding='same', activation='relu')(x)
    x = layers.Conv2D(512, (3, 3), padding='same', activation='relu')(x)
    skips.append(x)
    return skips

  def rightNetwork(self, inputs, num_classes, activation):
    """Expanding path (decoder): upsample, concat skip, 2x conv per stage.

    Args:
      inputs: the five encoder feature maps from ``leftNetwork``.
      num_classes: channels of the final 1x1 convolution.
      activation: final activation name ('sigmoid' for binary masks).
    """
    c_1, c_2, c_3, c_4, c_5 = inputs
    x = c_5
    for skip, filters in ((c_4, 256), (c_3, 128), (c_2, 64), (c_1, 32)):
      x = layers.UpSampling2D((2, 2))(x)
      x = layers.concatenate([skip, x], axis=3)   # channel-wise skip merge
      x = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
      x = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    x = layers.Conv2D(num_classes, (1, 1), strides=(1, 1), padding='same')(x)
    # The (H, W, 1) -> (H, W) reshape is only valid for a single channel;
    # guarded so num_classes > 1 keeps the channel axis (original code would
    # have crashed there).
    if num_classes == 1:
      x = layers.Reshape([self.input_height, self.input_width])(x)
    x = layers.Activation(activation)(x)
    return x

  def U_net(self, inputs, num_classes, activation):
    """Full U-Net: encoder followed by decoder."""
    leftout = self.leftNetwork(inputs)
    return self.rightNetwork(leftout, num_classes, activation)

  # --------------------------------------------------------------------------
  def build_mode(self):
    """Build and return the uncompiled keras Model.

    (Name kept as ``build_mode`` — likely a typo for build_model — to stay
    backward-compatible with existing callers.)
    """
    inputs = keras.Input(shape=[self.input_height, self.input_width, 3])
    outputs = self.U_net(inputs, num_classes=self.num_classes,
                         activation='sigmoid')
    return keras.Model(inputs=inputs, outputs=outputs)

  # --------------------------------------------------------------- data feed
  def _read_pair(self, img_path, seg_path, resize):
    """Load one (image, mask) pair; remap the mask to a binary foreground map.

    The image is scaled to [0, 1]; the mask is kept as raw label ids unless
    ``train_class`` is truthy, in which case it becomes 0/1.
    """
    img = cv2.imread(img_path, 1)
    if resize:
      # NOTE(review): only the image is resized, not the mask — this assumes
      # source files already match (input_width, input_height); confirm.
      img = cv2.resize(img, (self.input_width, self.input_height))
    img = img / 255
    seg = cv2.imread(seg_path, 0)
    if self.train_class:
      seg = np.where(seg == self.train_class, 1, 0)
    return img, seg

  def dataGenerator(self, mode):
    """Infinite batch generator.

    Args:
      mode: 'training' -> batches of ``batch_size`` from the train folders
            (images are NOT resized, matching the original behaviour);
            'validation' -> batches of 1 from the validation folders
            (images resized to the network input size).
    Yields:
      (x, y) numpy arrays: float images in [0, 1] and integer masks.
    """
    if mode == 'training':
      images = sorted(glob.glob(self.train_images + '/*.jpg'))
      instances = sorted(glob.glob(self.train_instances + '/*.png'))
      zipped = itertools.cycle(zip(images, instances))
      while True:
        x_train = []
        y_train = []
        for _ in range(self.batch_size):
          img_path, seg_path = next(zipped)
          img, seg = self._read_pair(img_path, seg_path, resize=False)
          x_train.append(img)
          y_train.append(seg)
        yield np.array(x_train), np.array(y_train)
    if mode == 'validation':
      images = sorted(glob.glob(self.val_images + '/*.jpg'))
      instances = sorted(glob.glob(self.val_instances + '/*.png'))
      zipped = itertools.cycle(zip(images, instances))
      while True:
        img_path, seg_path = next(zipped)
        img, seg = self._read_pair(img_path, seg_path, resize=True)
        yield np.array([img]), np.array([seg])

  # ------------------------------------------------------------ loss options
  def multi_category_focal_loss(self, y_true, y_pred):
    """Multi-category focal loss (kept for loading older saved models).

    NOTE(review): ``alpha`` is a fixed 5x1 weight column, so this assumes a
    5-class one-hot ``y_true``/``y_pred`` — verify before reuse.
    """
    epsilon = 1.e-7
    gamma = 2.0
    # alpha = tf.constant([[2],[1],[1],[1],[1]], dtype=tf.float32)
    alpha = tf.constant([[1], [1], [1], [1], [1]], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
    # p_t: probability assigned to the true class of each pixel.
    y_t = tf.multiply(y_true, y_pred) + tf.multiply(1 - y_true, 1 - y_pred)
    ce = -K.log(y_t)
    weight = tf.pow(tf.subtract(1., y_t), gamma)   # (1 - p_t)^gamma modulation
    fl = tf.matmul(tf.multiply(weight, ce), alpha)
    return tf.reduce_mean(fl)

  def focal_loss(self, y_true, y_pred):
    """Binary focal loss with class weighting (alpha on the positive class)."""
    gamma = 1.5
    alpha = 0.9
    pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))

    # Clip away from 0/1 to keep the logs finite.
    pt_1 = K.clip(pt_1, 1e-3, .999)
    pt_0 = K.clip(pt_0, 1e-3, .999)

    return -tf.reduce_mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) - tf.reduce_mean(
      (1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))

  # ------------------------------------------------------------------ driver
  def train(self):
    """Compile the model and fit it on the data generators.

    Saves the best checkpoint (by validation recall) during training and the
    final model at ``save_path`` afterwards.
    """
    G_train = self.dataGenerator(mode='training')
    G_eval = self.dataGenerator(mode='validation')
    if self.train_mode:
      # Resume: the custom loss must be registered so the .h5 file loads.
      model = keras.models.load_model(self.model_path,
                                      custom_objects={'focal_loss': self.focal_loss})
    else:
      model = self.build_mode()
    model.summary()
    # BUG FIX: the original passed lr_decay as Adam's 2nd positional argument,
    # which is beta_1 (not decay) — keywords make the intent unambiguous.
    model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=self.lr, decay=self.lr_decay),
      loss='binary_crossentropy',
      metrics=['binary_accuracy', 'Recall', 'AUC'],
    )
    checkpoint = keras.callbacks.ModelCheckpoint(self.save_path, monitor='val_Recall', verbose=1,
                                                 save_best_only=True, mode='max')

    callbacks = [checkpoint]
    model.fit_generator(G_train, 2000, validation_data=G_eval,
                        validation_steps=30, epochs=self.epochs,
                        callbacks=callbacks)
    model.save(self.save_path)

  def modelPred(self):
    """Visualise predictions vs. ground truth for 10 validation pairs."""
    model = keras.models.load_model(self.model_path,
                                    custom_objects={'multi_category_focal_loss1': self.multi_category_focal_loss})
    model.summary()
    images = sorted(glob.glob(self.val_images + '/*.jpg'))
    instances = sorted(glob.glob(self.val_instances + '/*.png'))
    zipped = itertools.cycle(zip(images, instances))

    for _ in range(10):
      img_path, seg_path = next(zipped)
      img = cv2.resize(cv2.imread(img_path, -1),
                       (self.input_width, self.input_height)) / 255
      seg = cv2.imread(seg_path, 0)
      x1_eval = np.expand_dims(img, 0)
      # BUG FIX: the model emits a single sigmoid map of shape (1, H, W);
      # argmax over the last axis (as the original did) reduces the width
      # dimension and yields nonsense.  Threshold the probabilities instead.
      pred = np.squeeze(model.predict(x1_eval)) > 0.5

      plt.subplot(121)
      plt.title("pred")
      plt.imshow(pred)

      plt.subplot(122)
      plt.title("ground truth")   # was mislabelled "pred"
      plt.imshow(seg)

      plt.show()
if __name__ == '__main__':
  # One network is instantiated and trained per target class.
  config = {
    'input_width': 256,          # images are resized to this size
    'input_height': 256,
    'num_classes': 1,            # binary segmentation -> one output channel
    'train_class': 4,            # which label id to train against
    'train_images': r'F:\BaiduNetdiskDownload\remoteSensing\U_net\train_x',      # training data location
    'train_instances': r'F:\BaiduNetdiskDownload\remoteSensing\U_net\train_y',
    'val_images': r'F:\BaiduNetdiskDownload\remoteSensing\U_net\test_x',         # validation data location
    'val_instances': r'F:\BaiduNetdiskDownload\remoteSensing\U_net\test_y',
    'epochs': 200,
    'lr': 0.0001,
    'lr_decay': 0.000001,
    'batch_size': 4,
    'model_path': r'F:\BaiduNetdiskDownload\remoteSensing\U_net\U_netClass3.h5',
    'save_path': r'F:\BaiduNetdiskDownload\remoteSensing\U_net\U_netClass4.h5',  # absolute path for the saved model
    'train_mode': 0,             # 0: build from scratch; truthy: resume from model_path
  }
  unet = UNet(**config)
  unet.train()                   # start training
  # unet.modelPred()
（原文評論區與「相關文章」列表為網頁殘留內容，此處從略。）