Pytorch實現一個簡單的生成對抗網絡GAN

最近看了一些GAN的資料,把自己易混淆的內容做一個總結

生成式模型

        我們以往通常接觸到的深度學習模型一般都是些判別模型,即通過訓練樣本訓練模型,然後利用模型對新樣本進行判別或預測。判別模型體現了深度學習的學習能力,然而,人工智能的強大,不應只有從已知中學習,還應該有創造能力,才就真正有趣。而生成式模型所體現的就是深度學習的創造能力。與判別式模型的工作流程恰恰相反,生成式模型是根據規則生成新的樣本。
        生成式模型,主要包括變分自編碼器(VAE)生成式對抗網絡(GAN)這兩種思路。VAE基於貝葉斯推理,其目的是潛在地建模,從模型中採樣新的數據。GAN是利用博弈論思想,以求得達到納什均衡(如何通俗的理解納什均衡點?)的判別器網絡(D)和生成器網絡(G)。

GAN架構

        GAN的直觀理解,可以想象一個名畫僞造者想僞造一幅達芬奇的畫作,開始時僞造技術不精,但他將自己的一些贗品和達芬奇的作品混在一起,請一個藝術鑑賞家進行真實性評估,並向僞造者反饋真僞程度。僞造者根據反饋,改進自己的贗品。隨着時間的推移,那麼造假者的造假能力越來卻強,鑑賞家的能力也越來越強。而贗品,則越來越像真畫。

        以上,便是GAN的原理。一個造假者G,一個鑑賞家D。他們訓練的目的都是爲了打敗對方。

        下圖是我從書中截取的GAN架構圖,簡單明瞭。

GAN的損失函數

     

  假設x表示圖像,D(x)表示判別網絡,是一個二元分類器,那麼它的輸出即爲圖片x來自訓練數據(而不是產生網絡輸出的假圖片)的概率。對於產生網絡,首先定義從標準正態分佈種採樣的數據z,則G(z)表示的是將向量z映射到空間的生成器函數。G的目標是估計訓練數據的分佈(Pdate)以生成假樣本。因此D(G(z))是產生網絡G的輸出是真實圖像的概率。即判別網絡D和產生網絡G在做一個極大極小的博弈,其中D試圖最大化它正確分辨真假數據(logD(x))的概率,而G試圖最小化D預測其輸出是假的概率(log(1-d(G(x))))。

 

Pytorch實現一個GAN

      

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt



# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001           # learning rate for generator
LR_D = 0.0001           # learning rate for discriminator
N_IDEAS = 5             # think of this as number of ideas for generating an art work (Generator)
ART_COMPONENTS = 15     # it could be total point G can draw in the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])
print(PAINT_POINTS)
# show our beautiful painting range
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
plt.legend(loc='upper right')
plt.show()


def artist_works():     # painting from the famous artist (real target)
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a-1)
    paintings = torch.from_numpy(paintings).float()
    return paintings

G = nn.Sequential(                      # Generator
    nn.Linear(N_IDEAS, 128),            # random ideas (could from normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)

D = nn.Sequential(                      # Discriminator
    nn.Linear(ART_COMPONENTS, 128),     # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                       # tell the probability that the art work is made by artist
)

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()   # something about continuous plotting

for step in range(10000):
    artist_paintings = artist_works()           # real painting from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)                    # fake painting from G (random ideas)

    prob_artist0 = D(artist_paintings)          # D try to increase this prob
    prob_artist1 = D(G_paintings)               # D try to reduce this prob

    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)      # reusing computational graph
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.01)

plt.ioff()
plt.show()

 

參考資料:《Python深度學習 基於pytorch》,《深度學習與圖像識別 原理與實踐》,莫煩Python

 

 

 

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章