參考:https://www.cnblogs.com/bonelee/p/9166084.html
GAN框架
對抗式生成網絡GAN(Generative Adversarial Net),是一個非常流行的生成式模型。 GAN 有兩個網絡,一個是 生成器generator,用來生成僞樣本;一個是判別器 discriminator,用於判斷樣本的真假。通過兩個網絡互相博弈和對抗來達到最好的生成效果,示意圖如下:
首先介紹KL散度(KL divergence),用於衡量兩種概率分佈的相似程度,數值越小,表示兩種概率分佈越接近。離散的概率分佈:
連續的概率分佈:
設真實樣本集服從分佈,其中是一個真實樣本。生成器產生的分佈設爲,是生成器G的參數,通過優化使得和儘可能接近,也就是生成的圖片與真實分佈一致。
從真實數據分佈裏面取樣個點,,根據給定的參數可以計算出生成這個樣本數據的似然爲:
爲最大化似然的結果:
是隨機噪聲,服從正態分佈或均勻分佈,通過生成器生成圖片,
其中爲示性函數:
這樣無法通過最大似然對生成器參數進行求解。因此採用判別器D分類與產生的誤差來取代極大似然估計。
下面是訓練判別器的示意圖,此時的生成器的權重被固定,真實圖片和生成圖片都會輸入到判別器中:
下面是訓練生成器的示意圖,此時的判別器的權重被固定,生成圖片輸入到判別器中:
誤差
對於判別器來說,希望能夠正確地分類真樣本和假樣本,所以需要最小化分類誤差,也可以說是最大化獎勵,這裏獎勵就是交叉熵的負數形式:
對於上述的獎勵函數,需要優化判別器D和生成器G兩個參數,此時可以採用的方法是固定一個優化另外一個。對於D來說,希望最大加獎勵V(D,G),對於生成器來說,希望最小化獎勵V(D,G),也就是說希望生成的圖片能騙過生成器。此時的優化目標爲:
當博弈達到納什平衡(Nash equilibrium)時,i.e.,,,G是最優的。
訓練過程
在一個epoch中,首先使用真實圖片和generator生成的假圖片來訓練discriminator是否能判別真假,即是二分類問題。之後只用generator生成假圖片在discriminator的誤差來訓練generator。
GAN優缺點
優點:
- 抽樣和生成很簡單直接。
- 訓練不涉及最大似然估計。
- 生成器不接觸真實樣本,對過擬合具有健壯性。
- 實驗上,GAN擅長捕獲分佈的模式。
缺點:
- 生成樣本的概率分佈是隱式的,無法直接計算概率。因此vanilla GANs只能用於生成樣本。
- 訓練不收斂。SGD通常在確定的條件下找到最有參數,可能不會收斂到一個Nash平衡點。
- mode-collapse模式坍塌。一般出現在GAN訓練不穩定的時候,具體表現爲生成出來的結果非常差,但是即使加長訓練時間後也無法得到很好的改善。
具體原因可以解釋如下:GAN採用的是對抗訓練的方式,G的梯度更新來自D,所以G生成的好不好,需要憑藉D的判斷。但是如果某一次G生成的樣本可能並不是很真實,但是D給出了正確的評價,或者是G生成的結果中一些特徵得到了D的認可,這時候G生成的結果是正確的,那麼接下來通過D生成的樣本還會得到高的評價,實際上G生成的並不怎麼樣,但是他們兩個就這樣自我欺騙下去了,導致最終生成結果缺失一些信息,特徵不全。
GAN生成MNIST數據集
以下使用GAN來生成手寫數字。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
z_dimension = 100 # the dimension of noise tensor
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.dis = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
x = self.dis(x)
return x
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.gen = nn.Sequential(
nn.Linear(z_dimension, 256),
nn.ReLU(True),
nn.Linear(256, 256),
nn.ReLU(True),
nn.Linear(256, 784),
nn.Tanh()
)
def forward(self, x):
x = self.gen(x)
return x
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)
def to_img(x):
out = 0.5 * (x + 1) # 將x的範圍由(-1,1)伸縮到(0,1)
out = out.view(-1, 1, 28, 28)
return out
D = Discriminator().to('cpu')
G = Generator().to('cpu')
criterion = nn.BCELoss()
D_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
# Training
def train(epoch):
print('\nEpoch: %d' % epoch)
D.train()
G.train()
all_D_loss = 0.
all_G_loss = 0.
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to('cpu'), targets.to('cpu')
num_img = targets.size(0)
real_labels = torch.ones_like(targets, dtype=torch.float)
fake_labels = torch.zeros_like(targets, dtype=torch.float)
inputs_flatten = torch.flatten(inputs, start_dim=1)
# Train Discriminator
real_outputs = D(inputs_flatten)
D_real_loss = criterion(real_outputs, real_labels)
z = torch.randn((num_img, z_dimension)) # Random noise from N(0,1)
fake_img = G(z) # Generate fake images
fake_outputs = D(fake_img.detach())
D_fake_loss = criterion(fake_outputs, fake_labels)
D_loss = D_real_loss + D_fake_loss
D_optimizer.zero_grad()
D_loss.backward()
D_optimizer.step()
# Train Generator
z = torch.randn((num_img, z_dimension))
fake_img = G(z)
G_outputs = D(fake_img)
G_loss = criterion(G_outputs, real_labels)
G_optimizer.zero_grad()
G_loss.backward()
G_optimizer.step()
all_D_loss += D_loss.item()
all_G_loss += G_loss.item()
print('Epoch {}, d_loss: {:.6f}, g_loss: {:.6f} '
'D real: {:.6f}, D fake: {:.6f}'.format
(epoch, all_D_loss/(batch_idx+1), all_G_loss/(batch_idx+1),
torch.mean(real_outputs), torch.mean(fake_outputs)))
# Save generated images for every epoch
fake_images = to_img(fake_img)
save_image(fake_images, 'MNIST_FAKE/fake_images-{}.png'.format(epoch + 1))
for epoch in range(40):
train(epoch)
torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')
運行40輪得到的結果:
在訓練完之後,可以得到generator的參數,可以將其單獨剝離出來進行圖像生成。此時,給generator任意生成的符合先驗分佈的噪聲向量,就會生成對應的圖片:
import torch
import torch.nn as nn
from torchvision.utils import save_image
z_dimension = 100 # the dimension of noise tensor
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.gen = nn.Sequential(
nn.Linear(z_dimension, 256),
nn.ReLU(True),
nn.Linear(256, 256),
nn.ReLU(True),
nn.Linear(256, 784),
nn.Tanh()
)
def forward(self, x):
x = self.gen(x)
return x
def to_img(x):
out = 0.5 * (x + 1)
out = out.view(-1, 1, 28, 28)
return out
G = Generator().to('cpu')
G.load_state_dict(torch.load('./generator.pth'))
def generate_synthetic_images(num_img):
G.eval()
z = torch.randn((num_img, z_dimension))
fake_img = G(z)
fake_images = to_img(fake_img)
print(fake_img)
save_image(fake_images, 'MNIST_GEN/synthetic_images.png')
if __name__ == '__main__':
generate_synthetic_images(100)