生成對抗式網絡 GAN及其衍生CGAN、DCGAN、WGAN、LSGAN、BEGAN等原理介紹、應用介紹及簡單Tensorflow實現

生成式對抗網絡(GAN,Generative Adversarial Networks)是一種深度學習模型,是近年來複雜分佈上無監督學習最具前景的方法之一。學界大牛Yann Lecun 曾說,令他最激動的深度學習進展就是生成式對抗網絡。最近正好看了這方面的一些介紹和論文,並用Tensorflow實現了兩個小例子,所以寫了這篇文章來作個簡單的小結。

本文主要分爲四個部分:

1.原始的GAN原理介紹;
2.GAN衍生的CGAN、DCGAN、WGAN、LSGAN、BEGAN等原理;
3.應用介紹;
4.tensorflow實現GAN小例子;

一、GAN原理介紹

學習GAN的第一篇論文當然由是 Ian Goodfellow 於2014年發表的 Generative Adversarial Networks(論文下載鏈接arxiv:[https://arxiv.org/abs/1406.2661] ),這篇論文可謂這個領域的開山之作。

GAN的基本原理其實並不複雜,模型通過框架中兩個模塊:生成模型(Generative Model)和判別模型(Discriminative Model)的互相博弈學習產生相當好的輸出。這裏以生成圖片爲例進行說明。假設我們有兩個網絡,G(Generator)和D(Discriminator)。正如它的名字所暗示的那樣,它們的功能分別是:
G是一個生成圖片的網絡,它接收一個隨機的噪聲z,通過這個噪聲生成圖片,記做G(z)。
D是一個判別網絡,判別一張圖片是不是“真實的”。它的輸入參數是x,x代表一張圖片,輸出D(x)代表x爲真實圖片的概率,如果爲1,就代表100%是真實的圖片,而輸出爲0,就代表不可能是真實的圖片。

在訓練過程中,生成網絡G的目標就是儘可能生成真實的圖片去欺騙判別網絡D。而D的目標就是儘量把G生成的圖片和真實的圖片分別開來。這樣,G和D構成了一個動態的“博弈過程”。

最後博弈的結果,在最理想的狀態下,G可以生成足以“以假亂真”的圖片G(z)。對於D來說,它難以判定G生成的圖片究竟是不是真實的,即達到了一個納什均衡,因此D(G(z)) = 0.5。此時,模型的收斂目標是生成器能夠從隨機噪聲生成真實數據。

這樣我們的目的就達成了:我們得到了一個生成式的模型G,它可以用來生成圖片。

以上只是大致說了一下GAN的核心原理,如何用數學語言描述呢?這裏直接摘錄論文裏的公式:
在這裏插入圖片描述
簡單分析一下這個公式:
整個式子由兩項構成。x表示真實圖片,z表示輸入G網絡的噪聲,而G(z)表示G網絡生成的圖片。
D(x)表示D網絡判斷真實圖片是否真實的概率(因爲x就是真實的,所以對於D來說,這個值越接近1越好)。而D(G(z))是D網絡判斷G生成的圖片的是否真實的概率。
G的目的:上面提到過,D(G(z))是D網絡判斷G生成的圖片是否真實的概率,G應該希望自己生成的圖片“越接近真實越好”。也就是說,G希望D(G(z))儘可能得大,這時V(D, G)會變小。因此我們看到式子的最前面的記號是min_G。
D的目的:D的能力越強,D(x)應該越大,D(G(x))應該越小。這時V(D,G)會變大。因此式子對於D來說是求最大(max_D)
下面這幅圖片很好地描述了這個過程:
在這裏插入圖片描述
那麼如何用隨機梯度下降法訓練D和G?論文中也給出了算法:
在這裏插入圖片描述
這裏黃色框圈出的部分是我們要特別注意的。第一步我們訓練D,D是希望V(G, D)越大越好,所以是加上梯度(ascending)。第二步訓練G時,V(G, D)越小越好,所以是減去梯度(descending)。整個訓練過程交替進行。

二、GAN衍生的CGAN、DCGAN、WGAN、LSGAN、BEGAN等原理

(1)CGAN:

在原始GAN中,目的是使得生成器能夠從隨機噪聲中生成真實數據,而CGAN(論文下載鏈接arxiv:https://arxiv.org/pdf/1411.1784.pdf)則更近一層,即給GAN加上條件,指導數據的生成過程,使得生成具有特定性質的樣本。以生成MNIST數據集的圖像樣本來說,原始GAN得到的生成器可以由隨機向量生成一張含有數字的圖像樣本,其中數字可能是0~9中的任意一個,而CGAN則是在生成器輸入時添加一個條件y,使得可以生成符合預期數字的圖像樣本,如生成含有數字1的圖像。如圖Figure1所示。價值函數變化如下:
在這裏插入圖片描述在這裏插入圖片描述
(2)DCGAN:

DCGAN(論文下載鏈接arxiv:https://arxiv.org/abs/1511.06434 )是應用比較廣泛的改進結構,基本採用卷積層替代了原始的全連接層,其中在生成器中採用帶步長的卷積代替了上採樣,極大地提升了GAN訓練時的穩定性及生成結果質量。如圖所示
GAN的主要問題是訓練過程不穩定,而DCGAN改進了其穩定性,原因在於:
(1)幾乎每層都使用batchnorm層,將特徵層的輸出歸一化到一起,加速訓練,提升訓練的穩定性;
(2)判別器中使用Leaky ReLU,防止梯度過度稀疏,生成器則仍然採用 ReLU,但最後輸出層採用Tanh;
(3)使用Adam優化器訓練,且最佳學習率爲0.0002;
(4)使用帶步長卷積替代上採樣層,卷積在提取圖像特徵上有較好的作用,並且使用卷積代替全連接層。
在這裏插入圖片描述
(3)WGAN:

爲了使得GAN的訓練更加穩定,與DCGAN不同的是,WGAN(論文下載鏈接arxiv:https://arxiv.org/pdf/1701.07875.pdf)主要從損失函數的角度進行改進:
A)判別器最後一層去掉Sigmoid;
B)生成器和判別器的loss不取Log;
C)對更新後的權重強制clip,如[-0.01,0.01],以滿足連續性條件;
D)推薦SGD、RMSProp等優化器,不要採用含有動量的優化算法,如Adam。

原始的GAN存在的問題有:判別器越好,生成器梯度消失越嚴重,生成器loss降不下去;判別器不好,生成器梯度不準,訓練不穩定,只有判別器訓練得不好不壞纔行,但這個尺度很難把握,甚至同一輪訓練的不同階段該尺度都不一樣,所以GAN才難以訓練。最小化生成器loss函數,會等價於最小化一個不合理的距離度量,使得最小化生成分佈與真實分佈的KL散度的同時又要最大化兩者的JS散度,導致梯度不穩定,同時也會使得生成器寧可多生成一些重複但較爲“安全”的樣本,也不願意生成多樣性的樣本,從而導致模式崩潰,即多樣性不足。

下圖所示爲標準GAN與WGAN對真實樣本分佈和生成樣本分佈判別的差異,標註GAN會出現梯度消失的情況,而WGAN則有較好的線性梯度。
在這裏插入圖片描述WGAN的貢獻主要在於從理論上給出了GAN訓練不穩定的原因,即交叉熵不適合衡量具有不相交部分的數據之間的距離,轉而使用Wassertein距離去衡量生成數據與真實數據之間的距離,理論上解決了訓練不穩定的問題;解決了模式崩潰問題,生成結果更加多樣;對GAN的訓練提供了一個指標,可以採用此指標來衡量GAN訓練的好壞,而不像之前那樣盲目訓練。

(4)LSGAN:

LSGAN(論文下載鏈接arxiv:https://arxiv.org/pdf/1611.04076.pdf)的主要目的也是採用最小二乘損失函數代替了GAN目標函數的交叉熵,從而解決了GAN訓練不穩定和生成圖像質量差、多樣性不足的問題。
在這裏插入圖片描述
其中a,b,c屬於超參數,a,b分別表示生成圖片和真實圖片的標記,c是生成器爲了使判別器認爲生成圖片爲真實樣本而定的值,這裏設定a=0,b=c=1。

論文主要回答了兩個問題:爲什麼最小二乘損失可以提高生成圖片質量;爲什麼最小二乘損失可以使得GAN訓練更穩定。對於第一個問題,論文認爲交叉熵作爲損失函數,會使得生成器不再優化那些被判別器識別爲真實圖片的生成圖片,即使這些生成圖片距離判別器的決策邊界仍較遠。原因在於生成器只需要完成混淆判別器的目標生成圖片即可,而最小二乘損失則在混淆判別器的前提下還得讓生成器把距離決策邊界較遠的生成圖片拉向決策邊界。對於第二個問題,論文認爲Sigmoid交叉熵損失容易達到飽和狀態,即梯度爲0,而最小二乘只在一個點達到飽和。

(5)BEGAN:

谷歌提出一種新的簡單強大的GAN,這是一種新的評價生成器生成質量的方法,不需要太多的訓練技巧即可實現快速穩定的訓練。以往的GAN及其變體是希望生成器生成的數據分佈儘可能地接近真實數據分佈,因此研究者們設計了各種損失函數,而BEGAN則不採用這種估計概率分佈的方法,即不直接去估計生成分佈Pg和真實分佈Pdata的差距,而是估計分佈的誤差分佈差距,只要分佈之間的誤差分佈相近,也可以認爲這些分佈是相近的。

BEGAN主要有3個貢獻:
(1)提出了一種新的簡單強大的GAN網絡結構,使用標準的訓練方式也能快速穩定的收斂。
(2)對於生成器和判別器的平衡提出了一種均衡的概念,提供了一個超參數,這個超參數用於平衡圖像的多樣性和生成質量。
(3)受WGAN啓發,提出了一種收斂程度估計。
BEGAN採用自編碼器作爲判別器;在生成器的設計上,使用Wasserstein距離衍生出的損失去匹配自編碼器的損失分佈,這是通過傳統的GAN目標加上一個用來平衡判別器和生成器的平衡項來實現的;還提出了一個衡量生成樣本多樣性的超參數Y:生成樣本損失的期望與真實樣本損失的期望值之比。Y值較低時會導致圖像多樣性較差,因爲此時判別器過於關注對真實圖像的自編碼。

三、應用介紹

自誕生以來,GAN引起了衆多學者的注意,成爲近幾年的熱點研究領域,原因在於其代表的無監督學習範式有着廣闊的前景。目前生成式對抗網絡已經有了許多成功的應用,如圖片生成、文字到圖片的合成、圖像超分辨率重建、圖像修復和紋理合成、風格遷移等,此外,GAN在目標檢測、行人識別、重定位等領域也有輔助作用。

因爲自己也是剛剛入門學習中,瞭解的還很片面,感興趣的小夥伴可以百度搜索進一步深入學習。當然,也歡迎掃描文章末尾的二維碼關注公衆號“StrongerTang”做交流和分享。

四、tensorflow實現GAN小例子

學習完以上內容以後,本人蔘考網上分享的代碼用tensorflow實現了Mnist數據圖像生成等兩個簡單小例子,考慮到本文篇幅已經較多,故打算另外單獨寫一篇文章分享。感興趣的朋友可以後續關注一下。

五、小結

GAN自誕生以來便成爲了研究熱點,無論是原理還是應用都取得了極大豐富和發展,並且仍在不斷向前發展中。因而,本文只是冰山一角的分享,還有很多內容自己也還沒有學習,更談不上分享。只希望自己能夠克服愛玩的缺點,多花點心思在學習上,能夠和大家一起學習到更多的知識。
最後這裏分享一個對GAN總結比較好的鏈接,感興趣的小夥伴可以進一步學習。(https://github.com/savan77/The-GAN-World

注:本文參考衆多論文原文及其它網絡資料,在此表示感謝。

在這裏插入圖片描述
?掃描上方二維碼關注
(關注微信公衆號“StrongerTang”,看更多文章,和小湯一起學習,一同進步!)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章