使用keras 搭建Gans在Mnist數據集上訓練總結

Gan的基本介紹

GAN(Generative Adversarial Networks)被Lan Goodfellow提出以後,各種Gan遍地開花,GAN掀起了一場技術革命在各個領域的應用都取得了重大突破, 圖靈獎得主Yann LeCun也稱近Gan是20年來深度學習領域最棒的想法。身爲小白的我也久仰Gan的大名,在寒假期間終於有時間能實操一下Gan的訓練(期待的搓搓小手(ง˃̀ꄃ˂́)۶)
在這裏插入圖片描述
GAN的核心思想是博弈,將網絡劃分爲生成器(Generator)和判別器(Discriminator),生成器並不直接接觸真實數據生成一張Fake image企圖騙過後面的判別器,而判別器則要不斷提高自身的判別能力以判別出圖像的真假,兩個網絡不斷博弈最終就能得到一個能夠生成逼真圖像的Generator, 而Generator是沒有直接看過真實圖像的。
在這裏插入圖片描述

Gan的訓練過程

訓練Gan時首先訓練判別器然後固定判別器的參數,再給generator輸入低維噪聲得到Fake Image交給判別器判別計算loss反向傳播更新Generator的參數。當判別器的準確率不足或者說生成器的擬合能力已經達到了判別器判別能力的上限再用帶真假標籤的真圖和假圖訓練更新判別器的參數,如此循環。
生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是爲了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。
在這裏插入圖片描述

Autoencoder

說到Gan就不得不提同樣是尋找低維特徵來表示高維數據分佈的自編碼器(Autoencoder)編碼器通過選擇和抽取特徵將數據編碼到低維隱層特徵空間z中,解碼器則相反。自編碼器的損失函數就是編碼時最大化信息保留和解碼之後最小化重構損失。Autoencoder通過自學習的方式能夠學習到由低維隱層空間解碼到原圖的方法,其中Decoder做的事情跟Gan的Generator的工作就很類似,所以Decoder也曾經被作爲一種生成器但效果不好,目前Autoencoder主要用來進行數據降維和特徵抽取(通過降低隱層特徵空間(編碼空間)的維度,達到降維的作用)。研究Autoencoder的意義在於更好的理解Generator在做些什麼。


自己實現的加入噪聲的deepAutoencoder效果還不錯(第一行爲原圖+高斯噪聲,第二行爲Autoencoder輸出,第三行爲原圖)
在這裏插入圖片描述

Vae

Vae(Variational Auto-Encoder)可以說是AE的豪華升級版,VAE的最大特點是模仿自動編碼機的學習預測機制,在可測函數之間進行編碼、解碼。同GAN類似,其最重要的idea是基於一個令人驚歎的數學事實:對於一個目標概率分佈,給定任何一種概率分佈,總存在一個可微的可測函數,將其映射到另一種概率分佈,使得這種概率分佈與目標的概率分佈任意的接近
對於未知的隱層空間z的分佈Gan是採用暴力破解的方法直接通過對抗擬合到z的分佈。而VAE的思想則不同,VAE假設z空間中的每一維特徵都影響着真實空間的某個特徵,例如下圖中的結果中從左上到右下字符逐漸傾倒且圈圈逐漸消失,就可以認爲x軸(z空間的某一維度)對應真實字符有沒有圈圈,y軸(z空間的另一維度)對應字符的傾倒程度。對於GAN的暴力求解,VAE的建模思路無疑要複雜的多,它更能體現思維的藝術感。
由於VAE在訓練過程中引入了噪聲(下圖中的e)使VAE具有了產生一些沒有見過的圖片的能力,比如訓練集中只有半月和滿月的圖片VAE能夠生成辦滿月的圖片。另外VAE的損失函數由兩部分構成:
1.重構損失函數(inputs和outputs的交叉熵)
2.學習到的隱分佈和先驗分佈的kl距離
模型的loss爲這兩項的和
在這裏插入圖片描述
手動實現的VAE結果展示:
在這裏插入圖片描述

Gan到DCGan

在我們的印象中: 卷積 = 對圖像處理來說很有用, GANs = 適合生成一些東西, 所以 卷積+GANs = 適合生成圖像? 於是DCGan(Deep Convolution Gan)應運而生。DCGan主要在一下幾點對Gan進行了改進:
1、G,D網絡不採取任何池化
2、G,D網絡每一層均使用批標準化處理(Batch-Normalization)
3、在G網絡中,激活函數除了最後一層外,都是用Relu函數,最後一層使用 tanh函數
4、D網絡中,激活函數均使用Leaky Relu函數。

由於替換掉了全連接,DCGan中使用了反捲和卷積操作實現數據的升降維,其中反捲積操作本質上應該叫轉置卷積(Transpose Convolution)除此之外DCGan與Gan並沒有什麼結構上的差異。

Gan(網圖 有些過擬合):

實現DCGan:
在這裏插入圖片描述
Tip:在利用公式計算Generator中的轉置卷積輸出維度時可能會遇到與對應的卷積操作不對應的情況,這是由於計算卷積時如果遇到輸出維度帶小數一般會取整,導致計算轉置卷積時得不到正確輸出維度,這時可以先用Generator的Output進行卷積倒推出Input,這樣得到的隱層空間維度就是正確的維度。

Gan到CGan(Conditional Gan)

原始的Gan只是學習到了輸入一個噪聲生成一個“數字”說生成一個看起來像“數字”的字符,但是Gan自己都不知道它生成的是哪個數字,Gan不像VAE能夠控制Encoder的輸入(z)獲取到沒有見過的圖像,Gan由隱空間到真實空間的轉換完全由Generator自行學習
爲了能讓Gan隨心所欲的生成想要的結果,CGan的作者在Generator的輸入層增加了由One hot 編碼得到的label y, 同時在訓練Discriminator的時候也引入label y
,這樣就使得Generator不僅僅要生成一個數字而且還要像“y”。 條件生成對抗的思想使Gan有了更多的用武之地。
在這裏插入圖片描述
CGan在Mnist上訓練3000輪結果:
在這裏插入圖片描述

CDCGan

我在CGan的基礎上將網絡中引入了卷積,並將ONE HOT替換爲Embedding,優化了網絡結構同時也增加了網絡深度,得到的效果要比CGAN好的多。
CDCGan300輪訓練結果:
在這裏插入圖片描述

Tips: 在搭建CGan時可能會遇到Conditional label與Generator Input或者Discriminator Input的拼接問題,建議使用Embedding即使Conditional label在Generator和Discriminator中輸入的維度不同也能train,而且Embedding層的參數還能自動更新。

由於用到的模型較多這裏只放CDCGan的核心代碼 Github鏈接:答應給我star才能點(๐•̆ ·̭ •̆๐)

class CGan(object):


    def __init__(self, config, weight_path = None):
        """
        CGan初始化函數
        :param config:配置文件
        :param weight_path: 已有權重路徑
        """
        self.config = config
        self.build_cgan_model()
        if weight_path is not None:
            self.cgan.load_weights(weight_path, by_name = True)


    def build_cgan_model(self):
        """
        build cgan model
        :return:
        """
        #初始化輸入
        self.generator_noise_input = Input(shape=(self.config.generator_noise_input_dim,))
        self.discriminator_image_input = Input(shape=self.config.discriminator_image_input_dim)
        self.contational_label_input = Input(shape=(1,), dtype='int32')

        #定義優化器
        self.optimizer = Adam(lr=2e-4, beta_1=0.5)

        #構建生成器模型與判別器模型
        self.discriminator_model = self.build_discriminator_model()
        self.discriminator_model.compile(optimizer=self.optimizer, loss='binary_crossentropy', metrics=['accuracy'])
        self.generator_model = self.build_generator()

        #構建CGan
        self.discriminator_model.trainable = False
        self.cgan_input = [self.generator_noise_input, self.contational_label_input]
        generator_output = self.generator_model(self.cgan_input)

        self.discriminator_input = [generator_output, self.contational_label_input]
        self.cgan_output = self.discriminator_model(self.discriminator_input)
        self.cgan = Model(self.cgan_input, self.cgan_output)

        self.cgan.compile(optimizer=self.optimizer, loss='binary_crossentropy')
        plot_model(self.cgan, "./model/CDCGan_Model.png")
        plot_model(self.generator_model, "./model/CDCGan_generator.png")
        plot_model(self.discriminator_model, "./model/CDCGan_discriminator.png")


    def build_discriminator_model(self):
        """

        :return:
        """
        model = Sequential()

        model.add(Conv2D(64, kernel_size=3, strides=2, padding='same',
                         input_shape=self.config.discriminator_image_input_dim))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(rate=self.config.dropout_prob))

        model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(rate=self.config.dropout_prob))

        model.add(Conv2D(256, kernel_size=3, strides=2, padding='same'))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(rate=self.config.dropout_prob))

        model.add(Conv2D(512, kernel_size=3, strides=2, padding='same'))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(rate=self.config.dropout_prob))

        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.config.discriminator_image_input_dim)
        label = Input(shape=(1,), dtype='int32')

        label_embedding = (Embedding(self.config.condational_label_num,
                                              np.prod(self.config.discriminator_image_input_dim))(label))

        label_embedding = Reshape(self.config.discriminator_image_input_dim)(label_embedding)
        model_input = multiply([img, label_embedding])
        validity = model(model_input)

        return Model([img, label], validity)


    def build_generator(self):
        """
        這是構建生成器網絡的函數
        :return:返回生成器模型generotor_model
        """
        model = Sequential()

        model.add(Dense(7*7*256, input_shape=(self.config.generator_noise_input_dim, ), activation='relu'))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Reshape((7, 7, 256)))

        model.add(Conv2DTranspose(128,kernel_size=3, strides=2, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))

        model.add(Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))

        model.add(Conv2DTranspose(32, kernel_size=3,  padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))

        model.add(Conv2DTranspose(self.config.discriminator_image_input_dim[2], kernel_size=3,
                                  padding='same', activation='tanh'))

        model.summary()

        noise = Input(shape=(self.config.generator_noise_input_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.config.condational_label_num, self.config.generator_noise_input_dim)(label))

        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

環境

操作系統:Windows 64 內存8g
顯卡:GTX1050
python:3.6.9
tensorflow-gpu:1.12
keras: 2.24

參考文獻:

https://www.sohu.com/a/325882199_114877
https://www.sohu.com/a/226209674_500659

結語

寒假一個月收穫不少,明天開始就要全力準備考研了要跟瞎玩模型告一段落了hhh下次寫博客可能就是一年後了 Fighting!!!
ps:深深體會到了窮人不配深度學習,一跑就是好幾天 電腦快給我烤化了,立個flag:明年暑假賺到錢嘗試在雲上跑吼吼吼

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章