WGAN:優化不完美的GAN

引言

​   前不久做了初代GAN的實驗,感受到了生成模型的強大,最近在看GAN的的變體WGAN感覺到了數學的強大,僅僅在原始的GAN上稍作修改就能達到不一樣的效果,真實的感受到了數學是的魅力,本次也是接着李宏毅老師的課件對WGAN的進行了一些整理。當然,非常推薦大家先看篇博文, 然後,讀一讀作者的paper,多讀幾遍,總會有醍醐灌頂的效果。不過,首先來看一下WGAN解決了那些問題,以便於我們很好理解其中的點。

we empirically show that WGANs cure the main training problems of GANs. In particular, training WGANs does not require maintaining a careful balance in training of the discriminator and the generator, and does not require a careful design of the network architecture either. The mode dropping phenomenon that is typical in GANs is also drastically reduced. One of the most compelling practical benefits of WGANs is the ability to continuously estimate the EM distance by training the discriminator to optimality. Plotting these learning curves is not only useful for debugging and hyper- parameter searches, but also correlate remarkably well with the observed sample quality.

原始GAN的問題

真實數據分佈和生成數據分佈很少重疊

​    原始GAN使用JS divergence來衡量PGP_GPdataP_{data}的有多像,但是在很多情況下PGP_GPdataP_{data}是不重疊的,如PGP_GPdataP_{data}和是高維空間的低緯流體,這使得它們之間即使存在有相互重疊的部分,也在很大程度上可以忽略。例如:二維空間中兩條相交的線段,它們在重疊在相交點的那一部分對於整體來說是微不足道的。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-281sV0KB-1573387985078)(Imgs/JS不太好.png)]

JS divergence本身的問題

​   如果PGP_GPdataP_{data}的不重疊,那麼我們的JS衡量相當於一個常數log2log2,而讓我們直觀上感覺,下圖中間部分會比較好一點比左面部分,因爲它使得兩個數據分佈更加的接近了,這也是我們想要的結果,但是我們使用JS來衡量數據分佈直接接近程度,卻在在它們不重疊的時候一直是一個常數,這使我們無法通過JS來判斷我們PGP_GPdataP_{data}的有多接近,我們失去了衡量彼此接近程度的標準。(爲什麼等於log2)

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-RZjw23Gm-1573387985080)(JS problem.png)]

新的衡量距離

EM距離的原理

​   所以,WGAN直接提出一種新的衡量距離的方法:Earth Mover’s Distance(推土機距離),我們把數據分佈P作爲一方土,把另一個數據分佈Q作爲目標。我們把P通過推土機移動到目標Q上。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-XkeugO9U-1573387985080)(/home/gavin/Documents/WGAN/Imgs/EM.png)]

​   那麼,我們把數據分佈P移動成數據分佈Q存在很多種方案,我們可以使最大移動量或者最小的移動量,在這裏我們限定使用最小移動量來作爲衡量標準,即我們希望最小的移動量達到最優的擬合結果。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-onheklow-1573387985081)(Imgs/EM.png)]

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-EOigrCSY-1573387985082)(Imgs/Best move.png)]

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-7xsxRPhP-1573387985083)(Imgs/矩陣移動.png)]

爲什麼EM距離好?

​   我們首先來有以下兩個概率密度函數,我們通過縮小θ\theta,使得θ\theta作爲我們的距離衡量標準。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-TijD7ue4-1573387985083)(Imgs/移動距離.png)]

​   下面是幾種距離衡量的標準,我們可以直觀的觀察得到,當θ\theta在不斷的縮減過程中,EM距離是一個連續的值,儘管數據分佈不重疊,但是,在JS和KL等距離下無法得到合適的測度,因爲它們不連續,並且是一個跳躍的值。
W(P0,Pθ)=θJ(P0,Pθ)={log2 if θ00 if θ=0KL(PθP0)=KL(P0Pθ)={+ if θ00 if θ=0 and δ(P0,Pθ)={1 if θ00 if θ=0 \begin{array}{l}{W\left(\mathbb{P}_{0}, \mathbb{P}_{\theta}\right)=|\theta|} \\ {J\left(\mathbb{P}_{0}, \mathbb{P}_{\theta}\right)=\left\{\begin{array}{ll}{\log 2} & {\text { if } \theta \neq 0} \\ {0} & {\text { if } \theta=0}\end{array}\right.} \\ {K L\left(\mathbb{P}_{\theta} \| \mathbb{P}_{0}\right)=K L\left(\mathbb{P}_{0} \| \mathbb{P}_{\theta}\right)=\left\{\begin{array}{ll}{+\infty} & {\text { if } \theta \neq 0} \\ {0} & {\text { if } \theta=0}\end{array}\right.} \\ {\text { and } \delta\left(\mathbb{P}_{0}, \mathbb{P}_{\theta}\right)=\left\{\begin{array}{ll}{1} & {\text { if } \theta \neq 0} \\ {0} & {\text { if } \theta=0}\end{array}\right.}\end{array}
​ 下面我直接使用原paper中的一部分,更加直觀來比較JS和EM距離。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-prhxVsZH-1573387985084)(Imgs/JS and EM.png)]
Figure 1: These plots show ρ(Pθ, P0) as a function of θ when ρ is the EM distance (left plot) or the JS divergence (right plot). The EM plot is continuous and provides a usable gradient everywhere. The JS plot is not continuous and does not provide a usable gradient.

WGAN

從EM過渡到WGAN

​   所以,我們基於EM距離提出了WGAN,我們提出了有約束的判別器(滿足1-Lipschitz),而Lipschitz連續條件限制了一個連續函數的最大局部變動幅度。然後,最大化V(G,D)V(G, D),而我們爲了滿足約束條件採取了一種非常暴力的方法"weight clipping", 原論文中也說了這是一種非常槽糕的方式去使得判別器滿足這個約束。權重裁剪的方式也很簡單,只需要在反向傳播把更新的權重強制夾到一個範圍內就可以。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-1Y5sCWH9-1573387985085)(Imgs/WGAN.png)]

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-5dJfRqzO-1573387985086)(Imgs/weight clipping.png)]

​   當然這樣的權重裁剪,適當的值也是非常重要的,所以原論文給出了權重裁剪過大或者過小時出現的問題,如下:

Weight clipping is a clearly terrible way to enforce a Lipschitz constraint. If the clipping parameter is large, then it can take a long time for any weights to reach their limit, thereby making it harder to train the critic till optimality. If the clipping is small, this can easily lead to vanishing gradients when the number of layers is big, or batch normalization is not used (such as in RNNs).

WGAN算法

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-k39G1c28-1573387985086)(./Imgs/算法原理.png)]

​ 所以,可以看出對比原始的GAN,WGAN只改了以下四個部分:

  1. 判別器最後一層去掉sigmoid

  2. 生成器和判別器的loss不取log

  3. 每次更新判別器的參數之後把它們的絕對值截斷到不超過一個固定常數c

  4. 不要用基於動量的優化算法(包括momentum和Adam),推薦RMSProp,SGD也行

對於上述第四點原paper中也做了相關的闡述,如下:

​ Finally, as a negative result, we report that WGAN training becomes unstable attimes when one uses a momentum based optimizer such as Adam [8] (with β1 > 0)on the critic, or when one uses high learning rates. Since the loss for the critic isnonstationary, momentum based methods seemed to perform worse. We identifiedmomentum as a potential cause because, as the loss blew up and samples got worse,the cosine between the Adam step and the gradient usually turned negative. Theonly places where this cosine was negative was in these situations of instability. Wetherefore switched to RMSProp [21] which is known to perform well even on verynonstationary problems [13]

而在真實的梯度更新的過程中,我們也能從下圖中看到不同生成器在最優的判別器下的梯度更新情況

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-jtNPhSoX-1573387985087)(Imgs/梯度更新.png)]
Figure 2: Optimal discriminator and critic when learning to differentiate two Gaussians.As we can see, the discriminator of a minimax GAN saturates and results in vanishing gradients. Our WGAN critic provides very clean gradients on all parts of the space.

Pytorch 復現

代碼改動

​   本次復現只在上一個版本上進行了局部的改動,這也如前面所說只需改動原始GAN算法的四個位置即可,改動結果如下:

  1. 判別器最後一層去掉sigmoid

    class NetD(nn.Module):
        """
        構建一個判別器,相當與一個二分類問題, 生成一個值
        """
    
        def __init__(self, opt):
            super(NetD, self).__init__()
    
            ndf = opt.ndf
            self.main = nn.Sequential(
                # 輸入96*96*3
                nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
    
                # 輸入32*32*ndf
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, True),
    
                # 輸入16*16*ndf*2
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, True),
    
                # 輸入爲8*8*ndf*4
                nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, True),
    
                # 輸入爲4*4*ndf*8
                nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=True),
    
                # 去除最後一層的sigmoid
                # nn.Sigmoid()
    
            )
    
        def forward(self, x):
            return self.main(x)
    
  2. 生成器和判別器的loss不取log

    生成器loss

    G_loss = -1 * (t.mean(netd(gen_img)))
    

    判別器loss

    D_loss = -1 * t.mean(netd(real_img)) + t.mean(netd(fake_img))
    
  3. 每次更新判別器的參數之後把它們的絕對值截斷到不超過一個固定常數c

     for p in netd.parameters():
                        p.data.clamp_(-opt.clip_value, opt.clip_value)  # opt.clip_value = 0.01
    
  4. 不要用基於動量的優化算法(包括momentum和Adam),推薦RMSProp,SGD也行

    optimizer_g = t.optim.SGD(netg.parameters(), lr=opt.lr1)
    optimizer_d = t.optim.SGD(netd.parameters(), lr=opt.lr2)
    

數據分析

  • 權重分佈

    ​   實驗過程中對判別器的權重進行收集,如下圖,其中weight 1表示判別器第一層的卷積權重分佈,以此類推,我在這裏取了四個層的權重來進行對比,可以看出經過權重裁剪之後的權重分佈偏向兩側。總感覺這樣太暴力了,不過這種暴力裁剪的方法已經在WGAN-GP中已經得到解決了,最近在跟進。

    [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-fQC1z7tV-1573387985088)(Imgs/weiht1.png)] [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-tm6PVy6D-1573387985089)(Imgs/weight5.png)]
    weight 1 weight 5
    [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-Cxm2LcMq-1573387985090)(Imgs/weight9.png)] [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-2npHdyhl-1573387985090)(Imgs/weight11.png)]
    weight9 weight 11
  • Loss

    D_loss:

    [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-JvnLnTSD-1573387985091)(Imgs/D_loss.png)]

    G_loss:

    [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-GFFaOmeY-1573387985091)(Imgs/G_loss.png)]

實驗效果

​   下圖是經過8000個epoch的效果,效果不是太好,可能是訓練次數太少的效果,畢竟8000個epoch對於煉丹來說,還是差點意思,但是,我們主要是學習思想和方法,當然如果能做到好的實驗效果也不能偷懶,哈哈!

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-noJB6nQB-1573387985092)(Imgs/7999.png)]

結論

WGAN從數學的角度,層層分析原始GAN所存在的問題,並且提出一種新的測度,這使得GAN更加具有魯棒性和穩定性,自己在學習過程中也深感數學之偉大。如果在翻閱本博客時,看到錯誤的地方請即使指出,Pytorch代碼我已經放到本人GitHub上,鏈接在下面參考文獻中。與君共勉。

參考文獻

令人拍案叫絕的Wasserstein GAN

Wasserstein GAN

WGAN的來龍去脈

W-GAN系 (Wasserstein GAN、 Improved WGAN)

pytorch WGAN

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章