00GANs

一、什麼是GAN?

  2014年,IanGoodfellow和他在蒙特利爾大學的同事發表了一篇令人驚歎的論文(GANs),提出了GANs(生成式對抗網絡)。 他們通過創新性地組合計算圖和博弈論,展示了給定足夠的建模能力,兩個相互對抗的模型能夠通過普通的反向傳播進行共同訓練.
  模型扮演了兩個不同的(確切地說,是對抗的)的角色。 給定一些真實數據集R,G是生成器,試圖創建看起來像真實數據的假數據,而D是判別器,從真實數據集或G中獲得數據並標記差異。Goodfellow給了一個很貼切的比喻,G像一夥努力用他們的輸出匹配真實圖畫的僞造者,而D是一幫努力鑑別差異的偵探。(唯一的不同是,僞造者G永遠不會看到原始數據 –而只能看到D的判斷。他們是一夥盲人騙子)。
在這裏插入圖片描述
  理想狀態下,D和G將隨着時間的推移而變得更好,直到G真正變成了原始數據的“僞造大師”,而D則徹底迷失,“無法分辨真假”。

二、五十行Pytorch搭建GAN

  這種強大的技術似乎需要大量的代碼纔可以,但是使用PyTorch,我們實際上可以在50行代碼下創建一個非常簡單的GAN。只需要考慮5個組件:
R:原始的、真正的數據;
I:進入生成器作爲熵源的隨機噪聲;
G:努力模仿原始數據的生成器;
D:努力將G從R中分辨出來的判別器;
訓練循環,我們在其中教G來迷惑D,教D鑑別G。
(1)R:在我們的例子中,我們將從最簡單的R- 一個鐘形曲線開始。 鐘形函數採用均值和標準差,並返回一個函數,該函數提供了使用這些參數的高斯分佈的正確形狀的樣本數據。在我們的示例代碼中,我們將使用均值4.0和標準差1.25。

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n))).to(device)

(2)I:進入生成器的輸入也是隨機的,但是爲了使我們的工作更難一點,讓我們使用一個均勻分佈,而不是一個普通的分佈。這意味着我們的模型G不能簡單地移動/縮放輸入以複製R,而是必須以非線性方式重塑數據。

def get_generator_input_sampler():
    return lambda m, n: (torch.rand(m, n)).to(device)

(3)G:生成器是一個標準的前饋網絡 - 兩個隱藏層,三個線性映射。我們使用ELU(exponential linear unit,ELU)。 G將從I獲得均勻分佈的數據樣本,並以某種方式模仿來自R的正態分佈樣本。

#生成器模爲全連接層
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f
        
    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        return x

(4)D:鑑別器代碼與生成器G的代碼非常相似;具有兩個隱藏層和三個線性映射的前饋網絡。 它將從R或G獲取樣本,並將輸出介於0和1之間的單個標量,解釋爲“假”與“真”。

#鑑別器模型
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hiddsen_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))

(5)最後,訓練在兩種模式之間循環交替:首先在真實數據與假數據上用準確的標籤訓練D,; 然後用不準確的標籤訓練G來愚弄D。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章