# TensorFlow / Keras reproduction of DeblurGAN (Kupyn et al., 2018).

import numpy as np
import matplotlib
from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.keras import backend as K
from tensorflow.keras.utils import plot_model
from IPython.display import Image
from tensorflow.keras.applications.vgg16 import VGG16

import cv2
import PIL
import json, os
import sys

import labelme
import labelme.utils as utils
import glob
import itertools

class DebulgGan():
    """Reproduction of DeblurGAN (Kupyn et al., 2018) in tf.keras.

    Components:
      * Generator  -- ResNet-style image-to-image network with a global
        input->output skip connection and a tanh output in [-1, 1].
      * Discriminator -- PatchGAN-like convolutional stack followed by
        dense layers producing a single score per image.
      * Combined model -- generator + frozen discriminator, trained with
        a VGG16 perceptual loss plus a Wasserstein adversarial loss.
    """

    def __init__(self):
        # --- architecture hyper-parameters -------------------------------
        self.image_shape = (256, 256, 3)
        self.ngf = 64            # base filter count of the generator
        self.ndf = 64            # base filter count of the discriminator
        self.input_nc = 3
        self.output_nc = 3
        self.input_shape_generator = (256, 256, 3)
        self.n_blocks_gen = 9    # number of residual blocks in the generator
        # --- training hyper-parameters -----------------------------------
        self.epochs = 100
        self.batch_size = 5
        self.train_number = 20000
        # --- data / output locations (adjust to your environment) --------
        self.blur_path = r'F:\BaiduNetdiskDownload\deblugData\train\x'
        self.sharp_path = r'F:\BaiduNetdiskDownload\deblugData\train\y'
        self.img_savepath = r'C:\Users\Administrator\Desktop\photo'
        self.model_path = r'C:\Users\Administrator\Desktop\photo\deblurGAN.h5'
        # --- build the networks ------------------------------------------
        self.generator = self.generator_model()
        self.discriminator = self.discriminator_model()
        self.model = self.generator_containing_discriminator_multiple_outputs()
        self.loss_model = self.bulid_loss_model()

    def res_block(self, input, filters, kernel_size=(3, 3), strides=(1, 1), use_dropout=False):
        """Residual block: Conv-BN-ReLU (+optional dropout) -> Conv-BN,
        added back onto the block input (identity shortcut)."""
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(input)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        if use_dropout:
            x = layers.Dropout(0.5)(x)

        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
        x = layers.BatchNormalization()(x)

        # Identity shortcut around the two convolutional layers.
        merged = layers.Add()([input, x])
        return merged

    def generator_model(self):
        """Build the generator (ResNet-block version of DeblurGAN)."""
        inputs = keras.Input(shape=self.image_shape)

        x = layers.Conv2D(filters=self.ngf, kernel_size=(7, 7), padding='same')(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        # Downsample twice, doubling the filter count each time.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            x = layers.Conv2D(filters=self.ngf * mult * 2, kernel_size=(3, 3), strides=2, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)

        # 9 ResNet blocks at the bottleneck resolution.
        mult = 2 ** n_downsampling
        for i in range(self.n_blocks_gen):
            x = self.res_block(x, self.ngf * mult, use_dropout=True)

        # Upsample back to full resolution, halving filters each step.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            x = layers.Conv2DTranspose(filters=int(self.ngf * mult / 2), kernel_size=(3, 3), strides=2, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)

        # Reduce channels to 3 (BGR) and squash into [-1, 1].
        x = layers.Conv2D(filters=self.output_nc, kernel_size=(7, 7), padding='same')(x)
        x = layers.Activation('tanh')(x)

        # Global skip: add the (already [-1, 1]) input, then halve so the
        # sum stays inside [-1, 1].
        outputs = layers.Add()([x, inputs])
        outputs = layers.Lambda(lambda z: z / 2)(outputs)

        model = keras.Model(inputs=inputs, outputs=outputs, name='Generator')
        return model

    def discriminator_model(self):
        """Build the discriminator (PatchGAN conv stack + dense head)."""
        n_layers, use_sigmoid = 3, False
        inputs = keras.Input(shape=self.image_shape)

        x = layers.Conv2D(filters=self.ndf, kernel_size=(4, 4), strides=2, padding='same')(inputs)
        x = layers.LeakyReLU(0.2)(x)

        # NOTE(review): min(2 ** n, 8) makes the first loop iteration reuse
        # nf_mult = 1 (same width as the stem); this matches the widely
        # copied Keras DeblurGAN reproduction, so it is kept as-is.
        nf_mult, nf_mult_prev = 1, 1
        for n in range(n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            x = layers.Conv2D(filters=self.ndf * nf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.LeakyReLU(0.2)(x)

        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        x = layers.Conv2D(filters=self.ndf * nf_mult, kernel_size=(4, 4), strides=1, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

        x = layers.Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)
        if use_sigmoid:
            x = layers.Activation('sigmoid')(x)

        x = layers.Flatten()(x)
        x = layers.Dense(1024, activation='tanh')(x)
        x = layers.Dense(1, activation='sigmoid')(x)

        model = keras.Model(inputs=inputs, outputs=x, name='Discriminator')
        return model

    def generator_containing_discriminator_multiple_outputs(self):
        """Combined model: blurred input -> (restored image, D score)."""
        inputs = keras.Input(shape=self.image_shape)
        generated_images = self.generator(inputs)
        outputs = self.discriminator(generated_images)
        model = keras.Model(inputs=inputs, outputs=[generated_images, outputs])
        return model

    def bulid_loss_model(self):
        """Frozen VGG16 feature extractor (block3_conv3) for the
        perceptual loss.  (Name keeps the original 'bulid' typo so any
        external callers are not broken.)"""
        vgg = VGG16(include_top=False, weights='imagenet', input_shape=self.image_shape)
        loss_model = keras.Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
        loss_model.trainable = False
        return loss_model

    def perceptual_loss(self, y_true, y_pred):
        """Mean squared error between VGG16 block3_conv3 feature maps."""
        return tf.reduce_mean(K.square(self.loss_model(y_true) - self.loss_model(y_pred)))

    def wasserstein_loss(self, y_true, y_pred):
        """Wasserstein loss: mean(labels * predictions), labels in {+1, -1}."""
        return tf.reduce_mean(y_true * y_pred)

    def compile(self):
        """Compile the discriminator alone, then freeze it and compile the
        combined generator+discriminator model."""
        # BUGFIX: use `learning_rate=` — the `lr=` alias is deprecated and
        # removed in recent TF 2.x releases.
        self.discriminator.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
            loss=self.wasserstein_loss)
        # Freeze D inside the combined model so generator updates do not
        # also move the discriminator's weights.
        self.discriminator.trainable = False
        self.model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
            loss=[self.perceptual_loss, self.wasserstein_loss],
            loss_weights=[100, 1])

    def load_data(self, blur_imgs, sharp_imgs, trian_idx, step):
        """Load one mini-batch of (blurred, sharp) image pairs.

        Returns a pair of float arrays of shape (batch_size, H, W, 3)
        scaled into [-1, 1].

        BUGFIX: the original divided by 255, producing values in
        [-0.5, 0.5]; that is inconsistent with the generator's tanh output
        range and with the `(x + 1) * 127.5` de-normalisation used in
        generate_sample_images.  Dividing by 127.5 yields [-1, 1].
        """
        blur_img = []
        sharp_img = []

        for j in range(self.batch_size):
            idx = trian_idx[step * self.batch_size + j]
            img = (cv2.imread(blur_imgs[idx], 1) - 127.5) / 127.5
            label = (cv2.imread(sharp_imgs[idx], 1) - 127.5) / 127.5
            blur_img.append(img)
            sharp_img.append(label)

        return np.array(blur_img), np.array(sharp_img)

    def train(self):
        """Alternating GAN training loop: one D step (real + fake batch),
        one combined-model step per mini-batch."""
        self.compile()
        self.model.summary()
        # ---- collect image paths (sorted so x[i] pairs with y[i]) -------
        blur_location = glob.glob(self.blur_path + '/*.png')
        blur_location.sort()
        sharp_location = glob.glob(self.sharp_path + '/*.png')
        sharp_location.sort()
        train_idx = np.arange(0, self.train_number, 1)
        steps = self.train_number // self.batch_size
        # Wasserstein targets: +1 for real, -1 for generated.
        output_true_batch, output_false_batch = np.ones((self.batch_size, 1)), -np.ones((self.batch_size, 1))
        # -----------------------------------------------------------------
        for epoch in range(self.epochs):
            train_idx = (tf.random.shuffle(train_idx)).numpy()  # reshuffle each epoch

            for step in range(steps):
                blur_imgs, sharp_imgs = self.load_data(blur_location, sharp_location, train_idx, step)
                gan_imgs = self.generator.predict(blur_imgs)

                # --- discriminator step ---------------------------------
                # BUGFIX: the original never restored trainable=True, so
                # the freeze/unfreeze toggle was asymmetric.  (NOTE: in
                # tf.keras the flag only takes effect at compile() time;
                # the toggle documents intent and matches the reference
                # implementation.)
                self.discriminator.trainable = True
                d_loss_real = self.discriminator.train_on_batch(sharp_imgs, output_true_batch)
                d_loss_fake = self.discriminator.train_on_batch(gan_imgs, output_false_batch)
                discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

                # --- generator step (through the combined model) --------
                self.discriminator.trainable = False
                generator_loss = self.model.train_on_batch(blur_imgs, [sharp_imgs, output_true_batch])

                print("epoch:%d step:%d [discriminator_loss: %f] [generator_loss: %f]" % (
                    epoch, step, discriminator_loss, generator_loss[0]))
                if step % 500 == 0:
                    self.generate_sample_images(gan_imgs, sharp_imgs, epoch, step)
            self.model.save(self.model_path)  # checkpoint after every epoch
            print('save model')

    def generate_sample_images(self, gan_imgs, sharp_imgs, epoch, step):
        """Save the first generated/sharp image of the batch to disk,
        mapping [-1, 1] back to uint8 pixel values."""
        idx = 0
        blur = ((gan_imgs[idx] + 1) * 127.5 - 0.0001).astype(np.uint8)
        sharp = ((sharp_imgs[idx] + 1) * 127.5 - 0.0001).astype(np.uint8)
        print((self.img_savepath + "/%d.%d_blur.png" % (epoch, step)))
        cv2.imwrite((self.img_savepath + "/%d.%d_blur.png" % (epoch, step)), blur)
        cv2.imwrite((self.img_savepath + "/%d.%d_sharp.png" % (epoch, step)), sharp)

        print('save plot')
if __name__ == '__main__':
    # Guard the entry point: constructing DebulgGan downloads VGG16
    # weights and builds three models, and train() starts a long training
    # run — none of that should happen on a mere import of this module.
    deblurGAN = DebulgGan()
    deblurGAN.train()