I've been reading up on GANs recently; this post is a summary of the points I found easiest to confuse.
Generative Models
Most deep learning models we usually encounter are discriminative models: we train a model on labeled samples, then use it to classify or predict new samples. Discriminative models showcase deep learning's ability to learn. But a truly powerful artificial intelligence should do more than learn from what is already known; it should also be able to create, and that is where things get interesting. Generative models embody deep learning's creative side. Their workflow is the opposite of a discriminative model's: a generative model learns the distribution of the training data and generates new samples from it.
Generative modeling has two main lines of work: the variational autoencoder (VAE) and the generative adversarial network (GAN). A VAE is grounded in Bayesian inference: it models the data through latent variables and samples new data from the learned model. A GAN instead borrows from game theory: a discriminator network (D) and a generator network (G) are trained against each other until they reach a Nash equilibrium (see: how to intuitively understand a Nash equilibrium point?).
GAN Architecture
For an intuition of how a GAN works, imagine a forger who wants to fake a da Vinci painting. At first his technique is poor, but he mixes his forgeries in with genuine da Vincis and asks an art connoisseur to assess their authenticity, and the connoisseur tells him how real each piece looks. The forger improves his fakes based on this feedback. Over time the forger gets better and better at forging, the connoisseur gets better and better at spotting fakes, and the forgeries look more and more like the real paintings.
That is the principle of a GAN: a forger G and a connoisseur D, each trained with the goal of beating the other.
The figure below is a GAN architecture diagram I clipped from the book; it is simple and clear.
The GAN Loss Function
Let x denote an image and D(x) the discriminator network, a binary classifier whose output is the probability that x came from the training data rather than from the generator. For the generator, first define a vector z sampled from a standard normal distribution; G(z) is then the generator function that maps z into data space. G's goal is to estimate the distribution of the training data, p_data, so that it can generate convincing fake samples, and D(G(z)) is therefore the probability that D judges the generator's output to be a real image. D and G thus play a minimax game: D tries to maximize the probability of correctly telling real data from fake (log D(x)), while G tries to minimize the probability that D predicts its output is fake (log(1 - D(G(z)))). Written out, the value function of this game is

min_G max_D V(D, G) = E_{x ~ p_data}[log D(x)] + E_{z ~ p_z}[log(1 - D(G(z)))]
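To make the loss concrete before the full example, here is a minimal sketch (the scores D_real and D_fake are made-up numbers, not from the book) showing that the hand-written log terms above are exactly binary cross-entropy with label 1 for real and label 0 for fake:

import torch
import torch.nn as nn

bce = nn.BCELoss()
D_real = torch.tensor([0.9, 0.8])  # hypothetical D(x): D's scores on real samples
D_fake = torch.tensor([0.3, 0.2])  # hypothetical D(G(z)): D's scores on fakes

# discriminator loss written out by hand, as in the formula above
d_loss_manual = -torch.mean(torch.log(D_real) + torch.log(1. - D_fake))
# the same loss via binary cross-entropy: real -> label 1, fake -> label 0
d_loss_bce = bce(D_real, torch.ones_like(D_real)) + bce(D_fake, torch.zeros_like(D_fake))
# the two values agree up to floating-point error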
Implementing a GAN in PyTorch
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001 # learning rate for generator
LR_D = 0.0001 # learning rate for discriminator
N_IDEAS = 5 # think of this as the number of ideas G uses to generate an artwork
ART_COMPONENTS = 15 # the total number of points G can draw on the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])
print(PAINT_POINTS)
# show our beautiful painting range
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
plt.legend(loc='upper right')
plt.show()
def artist_works():  # a painting from the famous artist (the real target)
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)
    paintings = torch.from_numpy(paintings).float()
    return paintings
G = nn.Sequential(  # Generator
    nn.Linear(N_IDEAS, 128),         # random ideas (could come from a normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
)
D = nn.Sequential(  # Discriminator
    nn.Linear(ART_COMPONENTS, 128),  # receives an artwork, either from the famous artist or from a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                    # outputs the probability that the artwork was made by the artist
)
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
plt.ion() # turn on interactive mode so the training plots update live
for step in range(10000):
    artist_paintings = artist_works()           # real paintings from the artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)                    # fake paintings from G (random ideas)

    # update G first: G tries to make D believe its paintings are real,
    # i.e. to minimize log(1 - D(G(z)))
    prob_artist1 = D(G_paintings)               # D's score on the fakes
    G_loss = torch.mean(torch.log(1. - prob_artist1))
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    # then update D: raise its score on real paintings, lower it on fakes.
    # detach() keeps D's update from backpropagating into G (and avoids the
    # retain_graph pitfall of stepping one network before the other's backward,
    # which errors out on recent PyTorch versions)
    prob_artist0 = D(artist_paintings)          # D tries to increase this prob
    prob_artist1 = D(G_paintings.detach())      # D tries to reduce this prob
    D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()
    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting')
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((0, 3))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)
plt.ioff()
plt.show()
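One variant worth knowing (it is not in the book's example): early in training D easily rejects G's poor fakes, so D(G(z)) is close to 0 and log(1 - D(G(z))) saturates, giving G very small gradients. The original GAN paper suggests the non-saturating trick of having G maximize log D(G(z)) instead, which in the loop above would be a one-line change:

G_loss = -torch.mean(torch.log(prob_artist1))  # non-saturating generator loss: maximize log D(G(z))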
References: 《Python深度學習 基於PyTorch》, 《深度學習與圖像識別 原理與實踐》, and 莫煩Python (Morvan Python).