簡述生成式對抗網絡

【轉載請註明出處】chenrudan.github.io

本文主要闡述了對生成式對抗網絡的理解,首先談到了什麼是對抗樣本,以及它與對抗網絡的關係,然後解釋了對抗網絡的每個組成部分,再結合算法流程和代碼實現來解釋具體是如何實現並執行這個算法的,最後給出一個基於對抗網絡改寫的去噪網絡運行的結果,效果雖然挺差的,但是有些地方還是挺有意思的。

1. 對抗樣本(adversarial examples)

14年的時候Szegedy在研究神經網絡的性質時,發現針對一個已經訓練好的分類模型,將訓練集中樣本做一些細微的改變會導致模型給出一個錯誤的分類結果,這種雖然發生擾動但是人眼可能識別不出來,並且會導致誤分類的樣本被稱爲對抗樣本,他們利用這樣的樣本發明了對抗訓練(adversarial training),模型既訓練正常的樣本也訓練這種自己造的對抗樣本,從而改進模型的泛化能力[1]。如下圖所示,在未加擾動之前,模型認爲輸入圖片有57.7%的概率爲熊貓,但是加了之後,人眼看着好像沒有發生改變,但是模型卻認爲有99.3%的可能是長臂猿。

圖1 對抗樣本的產生(圖來源[2])

這個問題乍一看很像過擬合,在Goodfellow在15年[3]提到了其實模型欠擬合也能導致對抗樣本,因爲從現象上來說是輸入發生了一定程度的改變就導致了輸出的不正確,例如下圖一,上下分別是過擬合和欠擬合導致的對抗樣本,其中綠色的o和x代表訓練集,紅色的o和x即對抗樣本,明顯可以看到欠擬合的情況下輸入發生改變也會導致分類不正確(其實這裏我覺得有點奇怪,因爲圖中所描述的對抗樣本不一定就是跟原始樣本是同分布的,感覺是人爲造的一個東西,而不是真實數據的反饋)。在[1]中作者覺得這種現象可能是因爲神經網絡的非線性和過擬合導致的,但Goodfellow卻給出了更爲準確的解釋,即對抗樣本誤分類是因爲模型的線性性質導致的,說白了就是因爲wTxwTx存在點乘,當xx的每一個維度上都發生改變x˜=x+ηx~=x+η,就會累加起來在點乘的結果上附加上一個比較大的和wTx˜=wTx+wTηwTx~=wTx+wTη,而這個值可能就改變了預測結果。例如[4]中給出的一個例子,假設現在用邏輯迴歸做二分類,輸入向量是x=[2,1,3,2,2,2,1,4,5,1]x=[2,−1,3,−2,2,2,1,−4,5,1],權重向量是w=[1,1,1,1,1,1,1,1,1,1]w=[−1,−1,1,−1,1,−1,1,1,−1,1],點乘結果是-3,類預測爲1的概率爲0.0474,假如將輸入變爲xad=x+0.5w=[1.5,1.5,3.5,2.5,2.5,1.5,1.5,3.5,4.5,1.5]xad=x+0.5w=[1.5,−1.5,3.5,−2.5,2.5,1.5,1.5,−3.5,4.5,1.5],那麼類預測爲1的概率就變成了0.88,就因爲輸入在每個維度上的改變,導致了前後的結果不一致。

圖2 過/欠擬合導致對抗樣本(圖來源[3])

如果認爲對抗樣本是因爲模型的線性性質導致的,那麼是否能夠構造出一個方法來生成對抗樣本,即如何在輸入上加擾動,Goodfellow給出了一種構造方法fast gradient sign method[2],其中JJ是損失函數,再對輸入xx求導,θθ是模型參數,ϵϵ是一個非常小的實數。圖1中就是ϵ=0.007ϵ=0.007

η=ϵsign(xJ(θ,x,y))(1)η=ϵsign(▽xJ(θ,x,y))(1)

這個構造方法在[4]中有比較多的實例,這裏截取了兩個例子來說明,用imagenet圖片縮放到64*64來訓練一個一層的感知機,輸入是64*64*3,輸出是1000,權重是64*64*3*1000,訓練好之後取權重矩陣對應某個輸出類別的一行64*64*3,將這行還原成64*64圖片顯示爲下圖中第二列,再用公式1的方法從第一列的原始圖片中算出第三列的對抗樣本,可以看到第一行從預測爲狐狸變成了預測爲金魚,第二行變成了預測爲校車。

圖3 構造對抗樣本(圖來源[4])

實際上不是隻有純線性模型纔會出現這種情況,卷積網絡的卷積其實就是線性操作,因此也有預測不穩定的情況,relu/maxout甚至sigmoid的中間部分其實也算是線性操作。因爲可以自己構造對抗樣本,那麼就能應用這個性質來訓練模型,讓模型泛化能力更強。因而[2]給定了一種新的目標函數也就是下面的式子,相當於對輸入加入一些干擾,並且也通過實驗結果證實了訓練出來的模型更加能夠抵抗對抗樣本的影響。

J˜(θ,x,y)=αJ(θ,x,y)+(1α)J(θ,x+ϵsign(xJ(θ,x,y)))(2)J~(θ,x,y)=αJ(θ,x,y)+(1−α)J(θ,x+ϵsign(▽xJ(θ,x,y)))(2)

對抗樣本跟生成式對抗網絡沒有直接的關係,對抗網絡是想學樣本的內在表達從而能夠生成新的樣本,但是有對抗樣本的存在在一定程度上說明了模型並沒有學習到數據的一些內部表達或者分佈,而可能是學習到一些特定的模式足夠完成分類或者回歸的目標而已。公式1的構造方法只是在梯度方向上做了一點非常小的變化,但是模型就無法正確的分類。此外還觀察到一個現象,用不同結構的多個分類器來學習相同數據,往往會將相同的對抗樣本誤分到相同的類中,這個現象看上去是所有的分類器都被相同的變化所幹擾了。

2. 生成式對抗網絡GAN

14年Goodfellow提出Generative adversarial nets即生成式對抗網絡[5],它要解決的問題是如何從訓練樣本中學習出新樣本,訓練樣本是圖片就生成新圖片,訓練樣本是文章就輸出新文章等等。如果能夠知道訓練樣本的分佈p(x)p(x),那麼就可以在分佈中隨機採樣得到新樣本,大部分的生成式模型都採用這種思路,GAN則是在學習從隨機變量zz到訓練樣本xx的映射關係,其中隨機變量可以選擇服從正太分佈,那麼就能得到一個由多層感知機組成的生成網絡G(z;θg)G(z;θg),網絡的輸入是一個一維的隨機變量,輸出是一張圖片。如何讓輸出的僞造圖片看起來像訓練樣本,Goodfellow採用了這樣一種方法,在生成網絡後面接上一個多層感知機組成的判別網絡D(x;θd)D(x;θd),這個網絡的輸入是隨機選擇一張真實樣本或者生成網絡的輸出,輸出是輸入圖片來自於真實樣本pdatapdata或者生成網絡pgpg的概率,當判別網絡能夠很好的分辨出輸入是不是真實樣本時,也能通過梯度的方式說明什麼樣的輸入更加像真實樣本,從而通過這個信息來調整生成網絡。從而GG需要儘可能的讓自己的輸出像真實樣本,而DD則儘可能的將不是真實樣本的情況分辨出來。下圖左邊是GAN算法的概率解釋,右邊是模型構成。

圖4 GAN算法框圖(圖來源[6])

GAN的優化是一個極小極大博弈問題,最終的目的是generator的輸出給discriminator時很難判斷是真實or僞造的,即極大化DD的判斷能力,極小化將GG的輸出判斷爲僞造的概率,公式如下。論文[5]中將下面式子轉化成了Jensen-shannon散度的形式證明了僅當pg=pdatapg=pdata時能得到全局最小值,即生成網絡能完全的還原出真實樣本分佈,並且證明了下式能夠收斂。(算法流程論文講的很清楚,這裏就不說了,後面結合代碼一起解釋。)

minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))](3)minGmaxDV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))](3)

以上是關於最基本GAN的介紹,最開始我看了論文後產生了幾個疑問,1.爲什麼不能直接學習GG,即直接學習一個zz到一個xx?2.GG具體是如何訓練的?3.在訓練的時候zzxx是一一對應關係嗎?在對代碼理解之後大概能夠給出一個解釋。

3. 代碼解釋

這部分主要結合tensorflow實現代碼[7]、算法流程和下面的變化圖[5]解釋一下具體如何使用DCGAN來生成手寫體圖片。

下圖中黑色虛線是真實數據的高斯分佈,綠色的線是生成網絡學習到的僞造分佈,藍色的線是判別網絡判定爲真實圖片的概率,標x的橫線代表服從高斯分佈x的採樣空間,標z的橫線代表服從均勻分佈z的採樣空間。可以看出GG就是學習了從z的空間到x的空間的映射關係。

圖5 GAN運行時各個概率分佈圖(圖來源[5])

a.起始情況

DD是一個卷積神經網絡,變量名是D,其中一層構造方式如下。

1
2
3
4
5
6
7
8
w = tf.get_variable('w', [4, 4, c_dim, num_filter],
initializer=tf.truncated_normal_initializer(stddev=stddev))
dconv = tf.nn.conv2d(ddata, w, strides=[1, 2, 2, 1], padding='SAME')
biases = tf.get_variable('biases', [num_filter],
initializer=tf.constant_initializer(0.0))
bias = tf.nn.bias_add(dconv, biases)
dconv1 = tf.maximum(bias, leak*bias)
...

GG是一個逆卷積神經網絡,變量名是G,其中一層構造方式如下。

1
2
3
4
5
6
7
8
9
10
w = tf.get_variable('w', [4, 4, num_filter, num_filter*2],
initializer=tf.random_normal_initializer(stddev=stddev))
deconv = tf.nn.conv2d_transpose(gconv2, w,
output_shape=[batch_size, s2, s2, num_filter],
strides=[1, 2, 2, 1])
biases = tf.get_variable('biases', [num_filter],
initializer=tf.constant_initializer(0.0))
bias = tf.nn.bias_add(deconv, biases)
deconv1 = tf.nn.relu(bias, name=scope.name)
...

GG的網絡輸入爲一個zdimzdim維服從-1~1均勻分佈的隨機變量,這裏取的是100.

1
2
batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim])
.astype(np.float32)

DD的網絡輸入是一個batch的64*64的圖片,既可以是手寫體數據也可以是GG的一個batch的輸出。

這個過程可以參考上圖的a狀態,判別曲線處於不夠穩定的狀態,兩個網絡都還沒訓練好。

b.訓練判別網絡

判別網絡的損失函數由兩部分組成,一部分是真實數據判別爲1的損失,一部分是GG的輸出self.G判別爲0的損失,需要優化的損失函數定義如下。

1
2
3
4
5
6
7
8
9
self.G = self.generator(self.z)
self.D, self.D_logits = self.discriminator(self.images)
self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)
self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
self.D_logits, tf.ones_like(self.D)))
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
self.D_logits_, tf.zeros_like(self.D_)))
self.d_loss = self.d_loss_real + self.d_loss_fake

然後將一個batch的真實數據batch_images,和隨機變量batch_z當做輸入,執行session更新DD的參數。

1
2
3
4
5
6
# update discriminator on real
d_optim = tf.train.AdamOptimizer(FLAGS.learning_rate,
beta1=FLAGS.beta1).minimize(d_loss, var_list=d_vars)
...
out1 = sess.run([d_optim], feed_dict={real_images: batch_images,
noise_images: batch_z})

這一步可以對比圖b,判別曲線漸漸趨於平穩。

c.訓練生成網絡

生成網絡並沒有一個獨立的目標函數,它更新網絡的梯度來源是判別網絡對僞造圖片求的梯度,並且是在設定僞造圖片的label是1的情況下,保持判別網絡不變,那麼判別網絡對僞造圖片的梯度就是向着真實圖片變化的方向。

1
2
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
self.D_logits_, tf.ones_like(self.D_)))

然後用同樣的隨機變量batch_z當做輸入更新

1
2
3
4
g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1)
.minimize(self.g_loss, var_list=self.g_vars)
...
out2 = sess.run([g_optim], feed_dict={noise_images:batch_z})

這一步可以對比圖c,pgpg的曲線在漸漸的向真實分佈靠攏。而網絡訓練完成之後可以看到pgpg的曲線與pdatapdata重疊在了一起,並且此時判別網絡已經難以區分真實與僞造,因此取值就固定在了1212

因而針對我之前的問題,2已經有了答案,針對1,爲什麼不能直接學習GG?這是因爲無法確定zzxx的一一對應關係,就像下圖,兩種對應關係,如果要肯定誰是對誰是錯,那麼就得加入一些先驗信息,甚至是直接對真實樣本的估計,那麼跟其他的方法不就一樣了麼。而問題3,在訓練的時候zzxx是一一對應關係嗎?我開始考慮這個問題是因爲不清楚是不是一個100維的noise變量就對應着一個手寫體變量圖片,但是現在考慮一下就應該明白在訓練的層面上不是一一對應的,甚至兩者在訓練DD的時候都是分開的,只是可能在分佈中會存在這樣一種對應關係而已。

圖6 z與x映射圖(圖來源[8])

4. 運行實例

這裏本來想用GAN來跑一個去噪的網絡,基於[7]的代碼改了一下輸入,從一個100維的noise向量變成了一張輸入圖片,同時將generator網絡的前面部分變成了卷積網絡,再連上原來的逆卷積,就成了一個去噪網絡,這裏我沒太多時間來細緻的調節網絡層數、參數等,就隨便試了一下,效果也不是特別的好。代碼在[9]中。首先我通過read_stl10.py對stl10數據集加上了均值爲0方差爲50的高斯噪聲,前後對比如下。

圖7 增加高斯噪聲前後對比

然後執行對抗網絡,會得到如下的去噪效果,從左到右分別是加了噪聲的輸入圖片,對應的generator網絡的輸出圖片,已經對應的乾淨圖片,效果不是特別好,輪廓倒是能學到一點,但是這個顏色卻沒學到。

圖8 去噪對比

5. 小結

剛開始搜資料的時候發現了對抗樣本,以爲跟對抗網絡有關係,就看了一下,後來看Goodfellow的論文時發現其實沒什麼關係,但是還是寫了一些內容,因爲這個東西的存在還是值得了解的,而對抗網絡這個想法真的太讚了,它將一個無監督問題轉化爲有監督,更加像一種learn的方式來學習數據應該是如何產生,而不是find的方式來找某些特徵,但是訓練也是一個難題,從我的經驗來看,特別容易過擬合,而且確實有一種對抗的感覺在裏面,因爲generator的輸入時好時壞,總的來說是個很棒的算法,非常期待接下來的研究。

6. 引用

[1] Intriguing properties of neural networks

[2] EXPLAINING AND HARNESSING ADVERSARIAL EXAMPLES

[3] Adversarial Examples

[4] Breaking Linear Classifiers on ImageNet

[5] Generative Adversarial Nets

[6] Quick introduction to GANs

[7] carpedm20/DCGAN-tensorflow

[8] Generative Adversarial Nets in TensorFlow (Part I)

[9] chenrudan/deep-learning/denoise_dcgan/

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章