GAN原理詳解

GAN的基本原理其實非常簡單,這裏以生成圖片爲例進行說明。假設我們有兩個網絡,G(Generator)和D(Discriminator)。正如它的名字所暗示的那樣,它們的功能分別是:
在這裏插入圖片描述

  • G是一個生成圖片的網絡,它接收一個隨機的噪聲z,通過這個噪聲生成圖片,記做G(z)。
  • D是一個判別網絡,判別一張圖片是不是“真實的”。它的輸入參數是x,x代表一張圖片,輸出D(x)代表x爲真實圖片的概率,如果爲1,就代表100%是真實的圖片,而輸出爲0,就代表不可能是真實的圖片。

數學公式:
在這裏插入圖片描述
這裏主要理解是怎麼訓練網絡的。

  • 整個式子由兩項構成。x表示真實圖片,z表示輸入G網絡的噪聲,而G(z)表示G網絡生成的圖片。
  • D(x)表示D網絡判斷真實圖片是否真實的概率(因爲x就是真實的,所以對於D來說,這個值越接近1越好)。而D(G(z))是D網絡判斷G生成的圖片的是否真實的概率。
  • G的目的:上面提到過,D(G(z))是D網絡判斷G生成的圖片是否真實的概率,G應該希望自己生成的圖片“越接近真實越好”。也就是說,G希望D(G(z))儘可能得大,這時V(D, G)會變小。因此我們看到式子的最前面的記號是min_G。
  • D的目的:D的能力越強,D(x)應該越大,D(G(x))應該越小。這時V(D,G)會變大。因此式子對於D來說是求最大(max_D)

還是不太理解,怎麼做到網絡的訓練

假設我們已經有了G網絡和D網絡,以及real_data。那麼我們怎麼對GD訓練呢。
再一次迭代中:
minG\min\limits_{G}
1. 首先輸入g_input(符合G輸入的隨機噪聲),對G前向傳播獲得g_fake_data=G(g_input)
2. 然後輸入至D網絡對其進行鑑別dg_fake_describe = D(g_fake_data)
3. 此時損失定義爲g_error = criterion(dg_fake_describe,1)。此時對這個損失進行反向傳播,爲了讓這個減少這個損失,就會更新G網絡的參數,是它生成的數據越來越接近真實。(記住這裏這個損失也對D網絡進行了反向傳播,但是並不對D網絡的參數進行更新)
maxD\max\limits_{D}
4. 獲得真實標籤輸入d_real_input, 並對D網絡前向傳播獲得d_real_descirbe = D(d_real_input)。計算其損失d_real_error= criterion(d_real_descirbe, 1)。這個損失反向傳播的時候可以告訴D網絡這是真實標籤。
5. 獲得G網絡的生成的假數據g_fake_data ,並計算其輸入到D網絡的損失函數d_fake_error = criterion(D(g_input), 0),這個損失函數可以保證D網絡能識別出假的圖像。
(這裏雖然計算G網絡的反向傳播梯度,但是並不對G網絡進行參數進行更新)

直到最後D網絡認爲生成的假圖像接近真實==(即D(g_fake_data)=0.5)==,表明網絡訓練成功。
這樣就實現了兩個網絡的訓練。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章