【閱讀筆記】Improved Training of Wasserstein GANs

Improved Training of Wasserstein GANs

Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[C]//Advances in Neural Information Processing Systems. 2017: 5767-5777.
GitHub: https://github.com/igul222/improved_wgan_training

Abstract

GAN雖然是個強有力的生成模型,但是訓練不穩定的缺點影響它的使用。剛剛提出的 Wasserstein GAN (WGAN) 使得 GAN 的訓練變得穩定,但是有時也會產生很差的樣本和不收斂。我們發現這些問題的原因常常是因爲 weight clipping 來滿足 判別器(critic,os.坑,研究了半天才領會這個意思)的 Lipschitz constraint。我們把 weight clipping 轉化爲成 判別器 的梯度範數關於輸入的懲罰。我們的方法優於 standard WGAN 和大部分的 GAN 的變種。

Introduction

Generative adversarial networks

Formally, the game between the generator G and the discriminator D is the minimax objective:
minGmaxDExpr[logD(x)]+Ex^pg[log(1D(x^))]min_Gmax_DE_{x\sim p_r}[logD(x)]+E_{\hat{x}\sim p_g}[log(1-D(\hat{x}))]

In practice, the generator is instead trained to maximize Ex^pg[log(D(x^))]E_{\hat{x}\sim p_g}[log(D(\hat{x}))]。因爲這樣可以規避當判別器飽和時的梯度消失。

Wasserstein GANs

The WGAN value function is constructed using the Kantorovich-Rubinstein duality to obtain
minGmaxDDExpr[D(x)]Ex^pg[D(x^)]min_Gmax_{D\in\mathscr{D}}E_{x\sim p_r}[D(x)]-E_{\hat{x}\sim p_g}[D(\hat{x})]

其中D\mathscr{D}是 1-Lipschitz functions。爲了使判別器滿足 k-Lipschitz 限制,需要將權重固定在[c,c][-c,c],k是由cc和模型結構所決定。

Difficulties with weight constraints

如下圖所示,發現進行 weight clipping 有兩個特點,一是會使得權重集中在所設範圍的兩端,二是會很容易造成梯度爆炸或梯度消失。這是因爲判別器要滿足 Lipschitz 條件,但是判別器的目標是使得真假樣本判別時差別越大越好,經過訓練後,權值的絕對值就集中在最大值附近了。
在這裏插入圖片描述

Gradient penalty

Algorithm 1 WGAN with gradient penalty. We use default values of λ=10\lambda=10, ncritic=5n_{critic}=5, KaTeX parse error: Expected 'EOF', got '\apha' at position 1: \̲a̲p̲h̲a̲=0.0001, β1=0\beta_1=0, β2=0.9\beta_2=0.9.
Require: The gradient penalty coefficient λ\lambda, the number of critic iterations per generator iteration ncriticn_critic, the batch size mm, Adam hyperparameters α,β1,β2\alpha,\beta_1,\beta_2.
Require: initial critic parameters w0w_0, initial generator parameters θ0\theta_0.

  • while θ\theta has not converged do
    • for t=1,...,ncritict=1, ..., n_{critic} do
      • for i=1,...,mi = 1, ..., m do
        • Sample real data xPrx\sim P_r, latent variable zp(z)z\sim p(z), a random number ϵU[0,1]\epsilon\sim U[0, 1].
        • x~Gθ(z)\tilde{x}\leftarrow G_{\theta}(z)
        • x^ϵx+(1ϵ)x^\hat{x}\leftarrow\epsilon x + (1 −\epsilon)\hat{x}
        • L(i)Dw(x)Dw(x~)+λ(x^Dw(x^)21)2L^{(i)}\leftarrow D_w(x) − D_w(\tilde{x}) + \lambda(||\nabla_{\hat{x}}D_w(\hat{x})||_2-1)^2
      • end for
      • wAdam(w1mi=1mL(i),w,α,β1,β2)w\leftarrow Adam(\nabla_w\frac{1}{m}\sum_{i=1}^mL^(i), w, \alpha, \beta_1, \beta_2)
    • end for
    • Sample a batch of latent variables {z(i)}i=1mp(z)\{z^{(i)}\}^m_{i=1}\sim p(z).
    • θAdam(θ1mI=1mDw(Gtheta(z)),θ,α,β1,β2)\theta\leftarrow Adam(\nabla_{\theta}\frac{1}{m}\sum_{I=1}^m−D_w(G_{theta}(z)), θ, \alpha, \beta_1, \beta_2)
  • end while

WGAN-GP 的創新點在與優化了代價函數
L=ExprDw(x)Expg[Dw(x~)]+λEx^px^[x^Dw(x^)21)2]L= E_{x\sim p_r}D_w(x) − E_{x\sim p_g}[D_w(\tilde{x})] + \lambda E_{\hat{x}\sim p_{\hat{x}}}[||\nabla_{\hat{x}}D_w(\hat{x})||_2-1)^2]

對權重增加懲罰項,使得在原始數據和生成數據中間地帶的權重的儘量小,相當於把 WGAN 的硬閾值轉化爲了軟閾值。
在這裏插入圖片描述

Experiments

在這裏插入圖片描述
在這裏插入圖片描述

Conclusion

從實驗上來看效果好於其他 GAN 方法,但是看其他資料說不一定好於WGAN,以後有空實驗一下看看效果。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章