[GAN] Generating handwritten digit images from the MNIST dataset with a GAN

I. Introduction to GANs

A GAN generally has two components: a generator and a discriminator.

The discriminator's goal is to tell, as accurately as possible, whether its input is fake data produced by the generator or real data.

The generator's goal is to fool the discriminator as well as possible, so that the discriminator judges the generated data to be real.

This adversarial game drives both the generator and the discriminator to improve, until the generator can produce data that passes for real.

The generator's input is a random vector and its output is data of the desired form.

The discriminator's input is data and its output is a number between 0 and 1, interpreted as the probability that the input is real.
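
Formally, this adversarial game is the minimax objective from the original GAN paper (quoted here only for reference; the code below uses the equivalent cross-entropy losses):

\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

where G is the generator, D is the discriminator, x is a real sample, and z is the random input vector.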

The code in this post targets TensorFlow 2.0.0 and mainly uses Keras.

II. Code walkthrough

1. Import the TensorFlow modules

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
from tensorflow.keras.layers import Dense,LeakyReLU,BatchNormalization,Reshape,Flatten
from tensorflow.keras.losses import BinaryCrossentropy
import numpy as np
import matplotlib.pyplot as plt

tensorflow.keras.datasets provides the MNIST handwritten digit dataset.

keras.layers provides the layers that are used directly below.

keras.losses provides the loss function.

2. Load and preprocess the data

(train_images,_),(_,_) = datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32')   # add a channel dimension
train_images = (train_images-127.5)/127.5   # scale pixels from [0, 255] to [-1, 1]

BATCH_SIZE = 256
BUFFER_SIZE = 60000

datasets = tf.data.Dataset.from_tensor_slices(train_images)
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

The pixel values are originally in [0, 255]; they are normalized to [-1, 1], the reshape adds an extra channel dimension, and finally the images are wrapped in a tf.data.Dataset, shuffled, and batched.
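
As a quick optional check (not in the original post), you can confirm that the preprocessing did what is described above:

print(train_images.min(), train_images.max())   # expected: -1.0 1.0
print(train_images.shape)                       # expected: (60000, 28, 28, 1)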

3. Define the generator model

def generator_model():
    model = keras.Sequential()

    model.add(Dense(256,input_shape=(100,),use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(512,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(784,use_bias=False,activation='tanh'))   # output layer: tanh keeps the generated pixels in [-1, 1]

    model.add(Reshape((28,28,1)))

    return model

The generator's input is a random 100-dimensional vector.

The generator is built from three fully connected layers. The last one is the output layer: since it must produce a 28x28 image it has 784 neurons with a tanh activation, which keeps the generated values in [-1, 1] (matching the normalized training data), and its output is then reshaped into a 28x28x1 image.
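
A small shape check (a hypothetical snippet, not part of the original code) confirms that the generator maps noise vectors to image tensors:

g = generator_model()
z = tf.random.normal([4,100])          # a batch of 4 random 100-dimensional vectors
print(g(z,training=False).shape)       # expected: (4, 28, 28, 1)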

4. Define the discriminator model

def discriminator_model():
    model = keras.Sequential()

    model.add(Flatten())

    model.add(Dense(512,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(256,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(1))

    return model

The discriminator consists of a Flatten layer followed by three fully connected layers. The last one has a single neuron so that each image gets one score. Note that this score is a raw logit rather than a probability (no sigmoid is applied here), which is why the loss function below is constructed with from_logits=True.
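
If you want the probability mentioned in the introduction, apply a sigmoid to the logit yourself. A hypothetical check (not part of the original code):

d = discriminator_model()
fake = tf.random.normal([1,28,28,1])            # one random "image"
print(tf.sigmoid(d(fake,training=False)))       # a value in (0, 1): the estimated probability that the input is real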

5. Define the loss functions and optimizers

cross_entropy = BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_out,fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss + fake_loss

def generator_loss(fake_out):
    return cross_entropy(tf.ones_like(fake_out),fake_out)

generator_opt = keras.optimizers.Adam(1e-4)
discriminator_opt = keras.optimizers.Adam(1e-4)

EPOCHS = 50
noise_dim = 100

num_exp_to_generate = 16

seed = tf.random.normal([num_exp_to_generate,noise_dim])
generator = generator_model()
discriminator = discriminator_model()

Here real_out is the discriminator's output when it is fed real images, and fake_out is its output when it is fed the generated images. The discriminator loss labels real outputs as 1 and fake outputs as 0; the generator loss labels fake outputs as 1, because the generator wants the discriminator to judge its images as real. The fixed seed noise (16 vectors of dimension noise_dim) is reused after every epoch so that the generated sample images can be compared across epochs.
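
As an illustration with made-up logits (not from the original post), the generator loss is small when the discriminator is fooled and large when it is not:

fake_out = tf.constant([[4.0],[-4.0]])    # hypothetical logits: the first image fools the discriminator, the second does not
print(generator_loss(fake_out).numpy())   # mean cross-entropy, dominated by the second image (about 2.0)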

6. Define the training step

def train_step(images):
    noise = tf.random.normal([BATCH_SIZE,noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        real_out = discriminator(images,training=True)
        gen_image = generator(noise,training=True)
        fake_out = discriminator(gen_image,training=True)
        gen_loss = generator_loss(fake_out)
        dis_loss = discriminator_loss(real_out,fake_out)
    gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)
    gradient_dis = dis_tape.gradient(dis_loss, discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_dis,discriminator.trainable_variables))
    return gen_loss,dis_loss
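
Two separate GradientTapes are used so that the generator and the discriminator each get their own gradients with respect to their own variables. As an optional speed-up (not in the original code), the step can be traced into a TensorFlow graph after it is defined:

train_step = tf.function(train_step)   # optional: later calls skip most of the Python overhead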

7. Define the plotting function

def generate_plot_image(gen_model,test_noise,epoch):
    pre_images = gen_model(test_noise,training=False)
    fig = plt.figure(figsize=(4,4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((pre_images[i,:,:,0] + 1 )/2,cmap='gray')
        plt.axis('off')
    plt.savefig('./images/image_at_epoch_{:04d}.png'.format(epoch))   # note: the ./images directory must already exist
    plt.close()

8. Start training

def train(dataset,epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            gen_loss,dis_loss = train_step(image_batch)
        print('epoch ',epoch+1,' finished training')
        print('gen_loss: ',gen_loss,'dis_loss: ',dis_loss)
        generate_plot_image(generator,seed,epoch)
    print('finished')

train(datasets,EPOCHS)

III. Training results

After only a handful of epochs you can already vaguely make out handwritten digits.

The result after training for 50 epochs is shown below.

IV. Full code

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
from tensorflow.keras.layers import Dense,LeakyReLU,BatchNormalization,Reshape,Flatten
from tensorflow.keras.losses import BinaryCrossentropy
import numpy as np
import matplotlib.pyplot as plt

(train_images,_),(_,_) = datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
train_images = (train_images-127.5)/127.5

BATCH_SIZE = 256
BUFFER_SIZE = 60000

datasets = tf.data.Dataset.from_tensor_slices(train_images)
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

def generator_model():
    model = keras.Sequential()

    model.add(Dense(256,input_shape=(100,),use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(512,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(784,use_bias=False,activation='tanh'))

    model.add(Reshape((28,28,1)))

    return model

def discriminator_model():
    model = keras.Sequential()

    model.add(Flatten())

    model.add(Dense(512,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(256,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(1))

    return model

cross_entropy = BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_out,fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss + fake_loss

def generator_loss(fake_out):
    return cross_entropy(tf.ones_like(fake_out),fake_out)

generator_opt = keras.optimizers.Adam(1e-4)
discriminator_opt = keras.optimizers.Adam(1e-4)

EPOCHS = 50
noise_dim = 100

num_exp_to_generate = 16

seed = tf.random.normal([num_exp_to_generate,noise_dim])
generator = generator_model()
discriminator = discriminator_model()

def train_step(images):
    noise = tf.random.normal([BATCH_SIZE,noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        real_out = discriminator(images,training=True)
        gen_image = generator(noise,training=True)
        fake_out = discriminator(gen_image,training=True)
        gen_loss = generator_loss(fake_out)
        dis_loss = discriminator_loss(real_out,fake_out)
    gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)
    gradient_dis = dis_tape.gradient(dis_loss, discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_dis,discriminator.trainable_variables))
    return gen_loss,dis_loss

def generate_plot_image(gen_model,test_noise,epoch):
    pre_images = gen_model(test_noise,training=False)
    fig = plt.figure(figsize=(4,4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((pre_images[i,:,:,0] + 1 )/2,cmap='gray')
        plt.axis('off')
    plt.savefig('./images/image_at_epoch_{:04d}.png'.format(epoch))
    plt.close()

def train(dataset,epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            gen_loss,dis_loss = train_step(image_batch)
        print('epoch ',epoch+1,' finished training')
        print('gen_loss: ',gen_loss,'dis_loss: ',dis_loss)
        generate_plot_image(generator,seed,epoch)
    print('finished')

train(datasets,EPOCHS)

 
