TensorFlow2-對抗生成網絡

TensorFlow2對抗生成網絡

What i can not create, i do not understand. 我不能創造的東西,我當然不能理解它。

簡介

對抗生成網絡(GAN)是時下非常熱門的一種神經網絡,它主要用於復現數據的分佈(distribution,或者數據的表示(representation))。儘管數據的分佈非常的複雜,但是依靠神經網絡強大的學習能力,可以學習其中的表示。其中,最典型的技術就是圖像生成。GAN的出現是神經網絡技術發展極具突破的一個創新。從2014年GAN誕生之時只能和VAE旗鼓相當,到2018年WGAN的以假亂真,GAN的發展是迅速的。

原理

GAN網絡由兩個部分組成,它們是生成器(Generator)和判別器(Discriminator)。將輸入數據與生成器產生的數據同時交給判別器檢驗,如果兩者的分佈接近(p_g接近p_r),則表示生成器逐漸學習數據的分佈,當接近到一定程度(判別器無法判別生成數據的真假),認爲學習成功。
因此關於生成器G和判別器D之間的優化目標函數如下,這就是GAN網絡訓練的目標。
minGmaxDL(D,G)=Expr(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]=Expr(x)[logD(x)]+Expz(x)[log(1D(x)] \begin{aligned} \min _{G} \max _{D} L(D, G) &=\mathbb{E}_{x \sim p_{r}(x)}[\log D(x)]+\mathbb{E}_{z \sim p_{z}(z)}[\log (1-D(G(z)))] \\ &=\mathbb{E}_{x \sim p_{r}(x)}[\log D(x)]+\mathbb{E}_{x \sim p_{z}(x)}[\log (1-D(x)]\end{aligned}
衡量兩種分佈之間的距離,GAN使用JS散度(基於KL散度推導)衡量兩種分佈的差異,然而當兩種分佈(生成器分佈和真實分佈)直接沒有交叉時,KL散度總是0,JS散度總是log2,這就導致JS散度無法很好量化兩種分佈的差異。同時,此時的將會出現梯度彌散,這也是很多GAN網絡難以訓練的原因。

因此,有人提出了衡量兩種分佈P和Q之間差異的方式是從P分佈到Q分佈需要經歷的變化(代價),可以理解爲下圖的一種分佈變爲另一種分佈需要移動的磚塊數目(移土距離,Earth Mover’s Distance, EM距離)。
B(γ)=xp,xqγ(xp,xq)xpxq B(\gamma)=\sum_{x_{p}, x_{q}} \gamma\left(x_{p}, x_{q}\right)\left\|x_{p}-x_{q}\right\|
W(P,Q))=minγΠB(γ) W(P, Q))=\min _{\gamma \in \Pi} B(\gamma)
在這裏插入圖片描述
在這裏插入圖片描述
基於此提出了Wasserstein Distance距離如下,將網絡中的JS散度替換爲Wasserstein Distance的GAN,稱爲WGAN,它可以從根本上結局不重疊的分佈距離難以衡量的問題從而避免訓練早期的梯度彌散。(必須滿足1-Lipschitz function,爲了滿足這個條件要進行weight clipping,但是即使weight clipping也不一定可以滿足1-Lipschitz function條件。)
W(Pr,Pg)=infγΠ(Pr,Pg)E(x,y)γ[xy] W\left(\mathbb{P}_{r}, \mathbb{P}_{g}\right)=\inf _{\gamma \in \Pi\left(\mathbb{P}_{r}, \mathbb{P}_{g}\right)} \mathbb{E}_{(x, y) \sim \gamma}[\|x-y\|]

因此,爲了滿足這個條件提出了WGAN-GP(Gradient Penalty),將這個條件寫入損失函數,要求必須在1附近。
在這裏插入圖片描述

GAN發展

從GAN思路被提出以來,產生了各種各樣的GAN,每一種GAN都有自己的名字,一般以首字母簡略稱呼(如今A-Z已經幾乎用完,可見這幾年GAN的發展迅速)。
在這裏插入圖片描述
其中,比較著名的有DCGAN(反捲積GAN,用於圖片擴張)。
在這裏插入圖片描述
此外,還有LSGAN、WGAN(儘管效果不如DCGAN,但是不需要花太多精力設計訓練過程)等。

GAN實戰

基於日本Anime數據集生成相應的二次元人物頭像,數據集的百度網盤地址如下,提取碼g5qa。
構建的GAN模型結構示意如下,判別器是一個基礎的CNN分類器,生成器是將隨機生成的數據進行升維成圖。
在這裏插入圖片描述
下面給出模型結構代碼,具體的訓練代碼可以在文末Github找到。

"""
Author: Zhou Chen
Date: 2019/11/23
Desc: About
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class Generator(keras.Model):

    def __init__(self):
        super(Generator, self).__init__()
        # 升維成圖
        # z: [b, 100] => [b, 3*3*512] => [b, 3, 3, 512] => [b, 64, 64, 3]
        self.fc = layers.Dense(3 * 3 * 512)

        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn1 = layers.BatchNormalization()

        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

    def call(self, inputs, training=None):
        # [z, 100] => [z, 3*3*512]
        x = self.fc(inputs)
        x = tf.reshape(x, [-1, 3, 3, 512])
        x = tf.nn.leaky_relu(x)
        x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = self.conv3(x)
        x = tf.tanh(x)  # 不使用relu

        return x


class Discriminator(keras.Model):

    def __init__(self):
        super(Discriminator, self).__init__()
        # 分類器
        # [b, 64, 64, 3] => [b, 1]
        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')

        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn3 = layers.BatchNormalization()

        # [b, h, w ,c] => [b, -1]
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))

        # [b, h, w, c] => [b, -1]
        x = self.flatten(x)
        # [b, -1] => [b, 1]
        logits = self.fc(x)

        return logits


def main():
    d = Discriminator()
    g = Generator()

    x = tf.random.normal([2, 64, 64, 3])
    z = tf.random.normal([2, 100])

    prob = d(x)
    print(prob)
    x_hat = g(z)
    print(x_hat.shape)


if __name__ == '__main__':
    main()

WGAN只需要在GAN代碼基礎上添加懲罰項,具體見Github。

補充說明

  • 本文介紹了GAN在TensorFlow2中的實現,更詳細的可以查看官方文檔。
  • 具體的代碼同步至我的Github倉庫歡迎star;博客同步至我的個人博客網站,歡迎查看其他文章。
  • 如有疏漏,歡迎指正。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章