TensorFlow Keras complete GAN network code (object-oriented): generating handwritten digits with the MNIST dataset

Preface

Much of the GAN code available online is poorly structured, and some of it cannot even be trained iteratively, so I decided to write a GAN implementation with a complete, easy-to-follow structure to help people who, like me, are just getting started with this kind of network. The GAN in this article is built entirely from fully connected layers, reproducing the simplest possible GAN architecture.

1. Code Structure

The code is made up of the following parts: global settings, the generator, the discriminator, the combined GAN model, training, sample image generation, and loading the saved model to generate images; a skeleton of the class is sketched below.
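As a quick orientation, here is a skeleton of the class that mirrors that layout (a sketch only; the complete listing follows in the next section):

class GAN():
    def __init__(self): ...                        # global settings: image shape, save paths, batch size, latent dim, epochs
    def build_generator(self): ...                 # generator: latent noise vector -> 28x28x1 image
    def build_discriminator(self): ...             # discriminator: image -> real/fake probability
    def build_model(self): ...                     # combined model (generator + frozen discriminator) used to train the generator
    def load_data(self): ...                       # load and normalise MNIST
    def train(self): ...                           # alternate discriminator/generator updates and save the model
    def generate_sample_images(self, epoch): ...   # save a 5x5 grid of generated digits
    def pred(self): ...                            # reload the saved model and generate one image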

2. Code

import numpy as np
from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class GAN():
    def __init__(self):      # global settings
        self.img_shape = (28, 28, 1)
        self.save_path = r'C:\Users\Administrator\Desktop\photo\GAN.h5'
        self.img_path = r'C:\Users\Administrator\Desktop\photo'
        self.batch_size = 20
        self.latent_dim = 100
        self.sample_interval = 1
        self.epoch = 100
        # build the generator, the discriminator and the combined GAN model
        self.generator_model = self.build_generator()
        self.discriminator_model = self.build_discriminator()
        self.model = self.build_model()


    def build_generator(self):  # generator

        input=keras.Input(shape=(self.latent_dim,))

        x=layers.Dense(256)(input)
        x=layers.LeakyReLU(alpha=0.2)(x)
        x=layers.BatchNormalization(momentum=0.8)(x)

        x = layers.Dense(512)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.BatchNormalization(momentum=0.8)(x)

        x = layers.Dense(1024)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.BatchNormalization(momentum=0.8)(x)

        x=layers.Dense(np.prod(self.img_shape),activation='sigmoid')(x)  # sigmoid output matches pixels normalised to [0, 1]
        output=layers.Reshape(self.img_shape)(x)

        model=keras.Model(inputs=input,outputs=output,name='generator')
        model.summary()
        return model

    def build_discriminator(self):  # discriminator

        input=keras.Input(shape=self.img_shape)

        x=layers.Flatten()(input)
        x=layers.Dense(512)(x)
        x=layers.LeakyReLU(alpha=0.2)(x)
        x=layers.Dense(256)(x)
        x=layers.LeakyReLU(alpha=0.2)(x)
        output=layers.Dense(1,activation='sigmoid')(x)

        model=keras.Model(inputs=input,outputs=output,name='discriminator')
        model.summary()
        return model

    def build_model(self):  # build the combined GAN model
        self.discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=keras.optimizers.Adam(0.0001, 0.000001),
                                    metrics=['accuracy'])

        self.discriminator_model.trainable = False  # freeze the discriminator so the combined model only trains the generator

        inputs = keras.Input(shape=(self.latent_dim,))
        img = self.generator_model(inputs)
        outputs = self.discriminator_model(img)
        model = keras.Model(inputs=inputs, outputs=outputs)
        model.summary()
        model.compile(optimizer=keras.optimizers.Adam(0.0001, 0.000001),
                      loss='binary_crossentropy',
                      )
        return model

    def load_data(self):
        (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
        train_images = train_images /255
        train_images = np.expand_dims(train_images, axis=3)
        print('img_number:',train_images.shape)
        return train_images

    def train(self):
        train_images = self.load_data()  # load the training data

        # labels: 1 for real images, 0 for generated images
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))

        step = int(train_images.shape[0] / self.batch_size)  # number of batches per epoch
        print('step:', step)

        for epoch in range(self.epoch):
            train_images = (tf.random.shuffle(train_images)).numpy()  # shuffle once per epoch
            if epoch % self.sample_interval == 0:
                self.generate_sample_images(epoch)

            for i in range(step):

                idx = np.arange(i*self.batch_size, i*self.batch_size+self.batch_size, 1)  # indices of this batch
                imgs = train_images[idx]  # real images for this batch
                noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))  # sample standard Gaussian noise
                gan_imgs = self.generator_model.predict(noise)  # generate images from the noise
                # ---------------------------------------------- train the discriminator
                discriminator_loss_real = self.discriminator_model.train_on_batch(imgs, valid)  # real images are labelled 1
                discriminator_loss_fake = self.discriminator_model.train_on_batch(gan_imgs, fake)  # generated images are labelled 0
                discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
                # ---------------------------------------------- train the generator
                noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
                generator_loss = self.model.train_on_batch(noise, valid)
                if i % 10 == 0:  # print every ten steps
                    print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                        epoch, i, discriminator_loss[0], 100 * discriminator_loss[1], generator_loss))



        self.model.save(self.save_path)  # save the combined model

    def generate_sample_images(self, epoch):  # generate and save sample images

        row, col = 5, 5  # number of rows and columns
        noise = np.random.normal(0, 1, (row * col, self.latent_dim))  # sample noise
        gan_imgs = self.generator_model.predict(noise)
        fig, axs = plt.subplots(row, col)  # 5x5 grid of subplots
        idx = 0

        for i in range(row):
            for j in range(col):
                axs[i, j].imshow(gan_imgs[idx, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                idx += 1
        fig.savefig(self.img_path + "/%d.png" % epoch)
        plt.close()  # close the figure

    def pred(self):  # load the saved model and generate an image
        model = keras.models.load_model(self.save_path)
        model.summary()
        noise = np.random.normal(0, 1, (1, self.latent_dim))

        # layers[1] of the saved combined model is the generator sub-model
        generator = keras.Model(inputs=model.layers[1].input, outputs=model.layers[1].output)
        generator.summary()
        img = np.squeeze(generator.predict(noise))
        plt.imshow(img, cmap='gray')
        plt.show()
        print(img.shape)



if __name__ == '__main__':
    gan = GAN()
    gan.train()
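The pred method is not called in the main block above. Once training has finished and GAN.h5 has been written to save_path, it can be invoked like this (a minimal usage sketch, assuming the paths configured in __init__ exist):

if __name__ == '__main__':
    gan = GAN()   # rebuilds the sub-models; the trained weights are reloaded inside pred()
    gan.pred()    # loads GAN.h5 from save_path, extracts the generator and plots one generated digit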


3. Experimental Results

[Sample images generated by the network at successive training epochs]
It can be seen that the digits generated by the network become more and more realistic.
