DCGAN
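1. What is a GAN
A GAN is a framework that lets a deep model learn the distribution of a dataset, so that new data following the same distribution can be generated from it.
It consists of a discriminator and a generator. The generator produces "fake" data, and the discriminator judges whether a given sample is real or fake. The two are trained together, so the model that makes fakes (G) and the model that detects them (D) both keep improving at their jobs.
Training ideally ends in a state resembling a Nash equilibrium, in which the discriminator can no longer tell real from fake: its accuracy is only 50%, no better than random guessing.
Let x be a sample of real data (for example a real image or other high-dimensional data), D() the discriminator, z a random variable, and G() the generator. D(x) outputs a scalar, the probability that x comes from the real data distribution. z is random noise (also called a latent-space vector), and G(z) maps z into the data space, with the goal of matching the real distribution P_data. The distribution of the generated samples G(z) is written P_G.
D(G(z)) is therefore a scalar: the probability that a generated image is real. D and G play a minimax game in which G minimizes and D maximizes. D tries to maximize the probability of classifying real and fake images correctly,
log D(x)
while G tries to minimize the probability that D recognizes its output as fake,
log(1 - D(G(z))).
The overall objective (loss function) can be written as:
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]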
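As a small numerical illustration of the GAN value function V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))], the sketch below estimates V from lists of discriminator outputs. The helper name `value_function` and the sample probabilities are made up for illustration and are not part of the tutorial code.

```python
import math

def value_function(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: discriminator outputs D(x) on real samples (probabilities in (0, 1))
    d_fake: discriminator outputs D(G(z)) on generated samples
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# Hypothetical outputs of a confident discriminator: V approaches its upper bound of 0.
v_confident = value_function([0.99, 0.98], [0.01, 0.02])

# If D outputs 0.5 for every sample (it cannot tell real from fake), V = -log 4.
v_equilibrium = value_function([0.5, 0.5], [0.5, 0.5])

print(v_confident, v_equilibrium)
```

D pushes V up towards 0, while G pulls it down; the value -log 4 corresponds to the 50% guessing state.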
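2. What is a DCGAN
A DCGAN is an extension of the GAN that uses convolutional layers in the discriminator and transposed-convolution layers in the generator.
The discriminator is built from strided convolution layers, batch normalization layers, and LeakyReLU activations. Its input is a 3x64x64 image and its output is a scalar probability that the input is real.
The generator is built from transposed-convolution layers, batch normalization layers, and ReLU activations. Its input is a latent vector z (for example drawn from a standard normal distribution), and its output is a 3x64x64 image.
The authors of the paper "Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks" also give tips on how to set up the optimizers, how to compute the loss functions, and how to initialize the model weights.
The initial imports are as follows: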
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
3. Inputs
The input parameters are:
- dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section
- workers - the number of worker threads for loading the data with the DataLoader
- batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128
- image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed.
- nc - number of color channels in the input images. For color images this is 3
- nz - length of latent vector
- ngf - relates to the depth of feature maps carried through the generator
- ndf - sets the depth of feature maps propagated through the discriminator
- num_epochs - number of training epochs to run. Training for longer will probably lead to better results but will also take much longer
- lr - learning rate for training. As described in the DCGAN paper, this number should be 0.0002
- beta1 - beta1 hyperparameter for Adam optimizers. As described in paper, this number should be 0.5
- ngpu - number of GPUs available. If this is 0, code will run in CPU mode. If this number is greater than 0 it will run on that number of GPUs
# Root directory for dataset
dataroot = "data/celeba"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
4. Data
The dataset is Celeb-A from the Chinese University of Hong Kong (CUHK).
# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=workers)
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# Plot some training images
real_batch = next(iter(dataloader))
# real_batch is a list
# real_batch[0] is a tensor of shape [128, 3, 64, 64], the usual 4D batch layout: 128 images, 3 channels, height 64, width 64
# real_batch[1] holds the 128 labels of those images; here they are all 0
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
# Uncomment the next line to display the figure
#plt.show()
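The `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` step computes `(value - mean) / std` per channel. A minimal sketch of that arithmetic in plain Python (the helper names `normalize` and `denormalize` are hypothetical) shows that the [0, 1] pixel range produced by `ToTensor()` is mapped to [-1, 1], the same range that the generator's final Tanh produces:

```python
def normalize(value, mean=0.5, std=0.5):
    """Per-channel arithmetic of transforms.Normalize: (value - mean) / std."""
    return (value - mean) / std

def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, useful before displaying an image."""
    return value * std + mean

# ToTensor() yields pixel values in [0, 1]; these map to [-1, 1].
lo, mid, hi = normalize(0.0), normalize(0.5), normalize(1.0)
print(lo, mid, hi)
```

When plotting, `make_grid(..., normalize=True)` rescales the values back into a displayable range, which is why the code above passes that flag.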
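5. Implementation
5.1 Weight Initialization
Following the DCGAN paper, the weights are initialized from a normal distribution with mean 0 and standard deviation 0.02. In the function below this applies to the convolutional layers; the batch-norm scale weights are drawn around 1.0 with the same standard deviation, and the batch-norm biases are set to 0.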
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
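To check what this initialization does, the following sketch applies the same `weights_init` rule to a small toy network and inspects the resulting statistics. The toy network and the variable names `toy`, `conv_std`, `bn_mean`, and `bn_bias_max` are assumptions made for this illustration only.

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Same rule as in this section: N(0, 0.02) for conv weights,
    # N(1, 0.02) for BatchNorm weights and 0 for BatchNorm biases
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)
# Hypothetical toy network, only used to inspect the initialization
toy = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1, bias=False),
    nn.BatchNorm2d(64),
)
toy.apply(weights_init)  # apply() calls weights_init on every submodule

conv_std = toy[0].weight.std().item()          # expected to be close to 0.02
bn_mean = toy[1].weight.mean().item()          # expected to be close to 1.0
bn_bias_max = toy[1].bias.abs().max().item()   # exactly 0
print(conv_std, bn_mean, bn_bias_max)
```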
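5.2 Generator
The generator G maps the latent-space vector z to the data space.
- nz=100: the length of the input vector z
- nc=3: the number of output channels (the three RGB channels for color images)
- ngf=64: the base number of feature-map channels in the generator (its layers use ngf*8, ngf*4, ngf*2, and ngf channels); it is not the 64x64 image size
- The transposed-convolution function is:
ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)
Its parameters are: 1. input channels, 2. output channels, 3. kernel size, 4. stride, 5. input padding, 6. output padding, 7. groups, 8. bias, 9. dilation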
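The spatial size produced by `ConvTranspose2d` follows the formula given in the PyTorch documentation: out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. The sketch below (the helper name `conv_transpose_out` is hypothetical) applies it to the kernel, stride, and padding settings of the five Generator layers in this section, starting from the 1x1 latent input:

```python
def conv_transpose_out(size, kernel_size, stride=1, padding=0,
                       output_padding=0, dilation=1):
    """Output height/width of ConvTranspose2d for a square input."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# (kernel_size, stride, padding) of the five generator layers
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector z enters as an nz x 1 x 1 tensor
sizes = []
for k, s, p in layers:
    size = conv_transpose_out(size, k, s, p)
    sizes.append(size)

print(sizes)  # [4, 8, 16, 32, 64]
```

This is why a different image_size requires changing the layer structure: the 64x64 output is a direct consequence of these five settings.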
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            # 100 input channels, 64*8 output channels, 4x4 kernel
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
Instantiate the generator and initialize its weights:
# Create the generator
netG = Generator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netG.apply(weights_init)
# Print the model
print(netG)
out:
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
)
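A quick way to sanity-check the architecture printed above is to push a small batch of noise through it and look at the output shape. The sketch below rebuilds the same layer sequence as a compact `nn.Sequential` stand-in (the helper `up_block` and the variable `g` are hypothetical names, and the weights are left at PyTorch's defaults rather than `weights_init`):

```python
import torch
import torch.nn as nn

nz, ngf, nc = 100, 64, 3  # same values as in the Inputs section

def up_block(in_ch, out_ch, stride, padding):
    # ConvTranspose2d -> BatchNorm2d -> ReLU, as in the Generator
    return [
        nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(True),
    ]

# Compact stand-in with the same layer sequence as the Generator
g = nn.Sequential(
    *up_block(nz, ngf * 8, 1, 0),
    *up_block(ngf * 8, ngf * 4, 2, 1),
    *up_block(ngf * 4, ngf * 2, 2, 1),
    *up_block(ngf * 2, ngf, 2, 1),
    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
    nn.Tanh(),
)

with torch.no_grad():
    z = torch.randn(2, nz, 1, 1)  # batch of 2 latent vectors
    fake = g(z)

print(fake.shape)  # torch.Size([2, 3, 64, 64])
```

Because the last layer is Tanh, every value of `fake` lies in [-1, 1].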
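5.3 Discriminator
The discriminator D is a binary classifier that decides whether an input image is real or fake. The image passes through a series of strided convolution, BatchNorm, and LeakyReLU layers, and a final Sigmoid activation outputs the probability.
The architecture can be extended with more layers if necessary. The DCGAN authors found experimentally that strided convolutions work better than pooling for downsampling, because the network can then learn its own pooling function. Batch normalization and LeakyReLU also promote healthy gradient flow, which is critical when G and D are trained together.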
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
Build D, initialize its weights, and print the model structure:
# Create the Discriminator
netD = Discriminator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netD.apply(weights_init)
# Print the model
print(netD)
out:
Discriminator(
(main): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(12): Sigmoid()
)
)
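The state sizes noted in the Discriminator comments (64 -> 32 -> 16 -> 8 -> 4 -> 1) follow from the `Conv2d` output-size formula in the PyTorch documentation: out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1. A minimal sketch, with a hypothetical helper name `conv_out`:

```python
def conv_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of Conv2d for a square input."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# (kernel_size, stride, padding) of the five discriminator layers
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input images are nc x 64 x 64
sizes = []
for k, s, p in layers:
    size = conv_out(size, k, s, p)
    sizes.append(size)

print(sizes)  # [32, 16, 8, 4, 1]
```

Each stride-2 layer halves the resolution, and the final 4x4 convolution without padding collapses the 4x4 map into the single value that Sigmoid turns into a probability.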
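5.4 Loss Function and Optimizers (loss & optimizer)
We use PyTorch's built-in Binary Cross Entropy loss (BCELoss), which is defined as:
\ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \quad l_n = -\left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]
We define the label of a real image as 1 and the label of a fake image as 0. We also set up two optimizers, one for D and one for G. In this example both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. To keep track of the generator's learning progress, we draw a fixed batch of latent vectors from a Gaussian distribution (fixed_noise). During training we periodically feed this fixed noise to G, so that the images it produces can be compared across iterations.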
# Initialize BCELoss function
criterion = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Establish convention for real and fake labels during training
real_label = 1.  # floats, so the label tensor has the float dtype BCELoss expects
fake_label = 0.
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
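To see how one BCE criterion yields both log terms of the GAN objective, here is a plain-Python sketch of the mean binary cross entropy (ignoring the clamping of log values that PyTorch applies for numerical stability). The function name `bce` and the sample probabilities are assumptions for illustration:

```python
import math

def bce(outputs, labels):
    """Mean binary cross entropy: -[y * log(x) + (1 - y) * log(1 - x)], averaged."""
    total = 0.0
    for x, y in zip(outputs, labels):
        total += -(y * math.log(x) + (1 - y) * math.log(1 - x))
    return total / len(outputs)

d_x = [0.9, 0.8]    # hypothetical D(x) on a real batch
d_g_z = [0.2, 0.1]  # hypothetical D(G(z)) on a fake batch

err_real = bce(d_x, [1, 1])    # label 1 selects -mean(log D(x))
err_fake = bce(d_g_z, [0, 0])  # label 0 selects -mean(log(1 - D(G(z))))
err_g = bce(d_g_z, [1, 1])     # real labels on fakes give -mean(log D(G(z)))

print(err_real, err_fake, err_g)
```

Minimizing `err_real + err_fake` is the same as maximizing log(D(x)) + log(1 - D(G(z))), and minimizing `err_g` is the same as maximizing log(D(G(z))).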
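5.5 Training
Training GANs is something of an art: poorly chosen hyperparameters easily lead to mode collapse. We construct separate mini-batches for real and fake images, and we adjust G's objective so that it maximizes log(D(G(z))).
5.5.1 Discriminator D
The goal of training D is to maximize the probability of correctly classifying an input as real or fake, updating the discriminator by ascending its stochastic gradient. In practice this means maximizing log(D(x)) + log(1 - D(G(z))).
This is done in two steps. First, a batch of real samples is taken from the training set and passed forward through D, the loss log(D(x)) is computed, and the gradients are computed in a backward pass.
Second, a batch of fake samples is built with the current generator and also passed forward through D to obtain the other half of the loss, log(1 - D(G(z))), whose gradients are accumulated in a second backward pass. With the gradients from both the all-real and the all-fake batch accumulated, we call one step of D's optimizer.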
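The two-step update relies on the fact that PyTorch adds gradients into `.grad` across successive `backward()` calls until they are zeroed. The sketch below checks this on a tiny stand-in for D: the single linear layer, the random placeholder batches, and all variable names are assumptions for illustration, not the tutorial's networks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for D: one linear layer followed by a sigmoid
d = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
criterion = nn.BCELoss()

real = torch.randn(8, 4)  # placeholder "real" batch
fake = torch.randn(8, 4)  # placeholder "fake" batch
ones = torch.ones(8)
zeros = torch.zeros(8)

# Two separate backward passes: the gradients add up in .grad
d.zero_grad()
criterion(d(real).view(-1), ones).backward()
criterion(d(fake).view(-1), zeros).backward()
accumulated = d[0].weight.grad.clone()

# One backward pass on the summed loss
d.zero_grad()
total = criterion(d(real).view(-1), ones) + criterion(d(fake).view(-1), zeros)
total.backward()
summed = d[0].weight.grad.clone()

same = torch.allclose(accumulated, summed, atol=1e-6)
print(same)
```

Calling the optimizer step once after both backward passes therefore updates D with the gradient of the full loss.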
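5.5.2 Generator G
In the original GAN formulation, G is trained by minimizing log(1 - D(G(z))) in order to generate better fakes. However, this objective does not provide sufficient gradients, especially early in training. As a fix, we maximize log(D(G(z))) instead. The statistics reported during training are:
- Loss_D
The discriminator loss, summed over the all-real and all-fake batches: log(D(x)) + log(1 - D(G(z)))
- Loss_G
The generator loss, computed as log(D(G(z)))
- D(x)
The average discriminator output for the all-real batch. It should start close to 1 and theoretically converge to 0.5 as G gets better
- D(G(z))
The average discriminator output for the all-fake batch. It should start close to 0 and theoretically converge to 0.5 as G gets better. The first number is measured before D is updated and the second after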
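The claim that log(1 - D(G(z))) gives weak gradients early in training can be checked with a little calculus. Writing D(G(z)) = sigmoid(a), where a is the pre-sigmoid output (logit) of D, the derivative of log(1 - sigmoid(a)) with respect to a is -sigmoid(a), while the derivative of log(sigmoid(a)) is 1 - sigmoid(a). The sketch below evaluates both for a hypothetical logit of -6, that is, a discriminator that confidently rejects the fake:

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the original generator objective."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """d/da of log(sigmoid(a)), the replacement objective."""
    return 1.0 - sigmoid(a)

# Early in training D rejects fakes confidently, so the logit is very negative.
a = -6.0
print(sigmoid(a), abs(grad_saturating(a)), grad_non_saturating(a))
```

With D(G(z)) near 0, the original objective passes almost no signal back to G, whereas the replacement objective has a gradient magnitude close to 1.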
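Training time depends on the number of epochs (passes over the whole training set) and on the size of the dataset. The code is as follows: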
# Training Loop
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # Labels must be float tensors for BCELoss
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label) # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
out:
Starting Training Loop...
[0/5][0/1583] Loss_D: 2.0937 Loss_G: 5.2060 D(x): 0.5704 D(G(z)): 0.6680 / 0.0090
[0/5][50/1583] Loss_D: 0.1916 Loss_G: 9.5846 D(x): 0.9472 D(G(z)): 0.0364 / 0.0002
[0/5][100/1583] Loss_D: 4.0207 Loss_G: 21.2494 D(x): 0.2445 D(G(z)): 0.0000 / 0.0000
[0/5][150/1583] Loss_D: 0.5569 Loss_G: 3.1977 D(x): 0.7294 D(G(z)): 0.0974 / 0.0609
[0/5][200/1583] Loss_D: 0.2320 Loss_G: 3.3187 D(x): 0.9009 D(G(z)): 0.0805 / 0.0659
[0/5][250/1583] Loss_D: 0.7203 Loss_G: 5.9229 D(x): 0.8500 D(G(z)): 0.3485 / 0.0062
[0/5][300/1583] Loss_D: 0.6775 Loss_G: 4.0545 D(x): 0.8330 D(G(z)): 0.3379 / 0.0353
[0/5][350/1583] Loss_D: 0.7549 Loss_G: 5.9064 D(x): 0.9227 D(G(z)): 0.4109 / 0.0084
[0/5][400/1583] Loss_D: 1.0655 Loss_G: 2.5097 D(x): 0.4933 D(G(z)): 0.0269 / 0.1286
[0/5][450/1583] Loss_D: 0.6321 Loss_G: 2.7811 D(x): 0.6453 D(G(z)): 0.0610 / 0.1026
[0/5][500/1583] Loss_D: 0.5064 Loss_G: 4.1399 D(x): 0.9475 D(G(z)): 0.3009 / 0.0350
[0/5][550/1583] Loss_D: 0.3838 Loss_G: 4.0321 D(x): 0.8221 D(G(z)): 0.1218 / 0.0331
[0/5][600/1583] Loss_D: 0.5549 Loss_G: 4.6055 D(x): 0.8230 D(G(z)): 0.2049 / 0.0171
[0/5][650/1583] Loss_D: 0.2821 Loss_G: 6.8137 D(x): 0.8276 D(G(z)): 0.0164 / 0.0027
[0/5][700/1583] Loss_D: 0.6422 Loss_G: 5.0119 D(x): 0.8267 D(G(z)): 0.2827 / 0.0146
[0/5][750/1583] Loss_D: 0.4332 Loss_G: 4.3659 D(x): 0.9239 D(G(z)): 0.2307 / 0.0291
[0/5][800/1583] Loss_D: 0.5344 Loss_G: 3.4145 D(x): 0.7208 D(G(z)): 0.0891 / 0.0744
[0/5][850/1583] Loss_D: 0.8094 Loss_G: 2.9318 D(x): 0.5903 D(G(z)): 0.0602 / 0.0979
[0/5][900/1583] Loss_D: 0.1598 Loss_G: 6.4141 D(x): 0.9228 D(G(z)): 0.0630 / 0.0046
[0/5][950/1583] Loss_D: 0.5083 Loss_G: 5.5467 D(x): 0.9226 D(G(z)): 0.2916 / 0.0112
[0/5][1000/1583] Loss_D: 0.6738 Loss_G: 3.9958 D(x): 0.7622 D(G(z)): 0.2480 / 0.0410
[0/5][1050/1583] Loss_D: 0.2155 Loss_G: 3.8838 D(x): 0.9092 D(G(z)): 0.0819 / 0.0432
[0/5][1100/1583] Loss_D: 1.1708 Loss_G: 1.9610 D(x): 0.4709 D(G(z)): 0.0064 / 0.2448
[0/5][1150/1583] Loss_D: 0.7506 Loss_G: 6.9292 D(x): 0.8797 D(G(z)): 0.3728 / 0.0019
[0/5][1200/1583] Loss_D: 0.2133 Loss_G: 5.5082 D(x): 0.9436 D(G(z)): 0.1272 / 0.0102
[0/5][1250/1583] Loss_D: 0.5156 Loss_G: 3.8660 D(x): 0.8073 D(G(z)): 0.1993 / 0.0357
[0/5][1300/1583] Loss_D: 0.4848 Loss_G: 5.0770 D(x): 0.9170 D(G(z)): 0.2847 / 0.0109
[0/5][1350/1583] Loss_D: 0.6596 Loss_G: 4.7626 D(x): 0.8414 D(G(z)): 0.3232 / 0.0145
[0/5][1400/1583] Loss_D: 0.2799 Loss_G: 5.1604 D(x): 0.9154 D(G(z)): 0.1494 / 0.0156
[0/5][1450/1583] Loss_D: 0.4756 Loss_G: 2.9344 D(x): 0.8164 D(G(z)): 0.1785 / 0.0955
[0/5][1500/1583] Loss_D: 0.3904 Loss_G: 2.3755 D(x): 0.7652 D(G(z)): 0.0587 / 0.1328
[0/5][1550/1583] Loss_D: 1.2817 Loss_G: 1.2689 D(x): 0.3769 D(G(z)): 0.0221 / 0.3693
[1/5][0/1583] Loss_D: 0.5365 Loss_G: 3.0092 D(x): 0.7437 D(G(z)): 0.1574 / 0.0836
[1/5][50/1583] Loss_D: 0.4959 Loss_G: 5.4086 D(x): 0.9422 D(G(z)): 0.2960 / 0.0086
[1/5][100/1583] Loss_D: 0.2685 Loss_G: 3.6553 D(x): 0.8455 D(G(z)): 0.0640 / 0.0457
[1/5][150/1583] Loss_D: 0.6243 Loss_G: 4.6128 D(x): 0.8467 D(G(z)): 0.2878 / 0.0203
[1/5][200/1583] Loss_D: 0.4369 Loss_G: 2.8268 D(x): 0.7591 D(G(z)): 0.0871 / 0.0871
[1/5][250/1583] Loss_D: 0.4244 Loss_G: 3.7669 D(x): 0.8641 D(G(z)): 0.1952 / 0.0369
[1/5][300/1583] Loss_D: 0.7487 Loss_G: 2.5417 D(x): 0.6388 D(G(z)): 0.0948 / 0.1263
[1/5][350/1583] Loss_D: 0.5359 Loss_G: 2.9435 D(x): 0.6996 D(G(z)): 0.0836 / 0.0864
[1/5][400/1583] Loss_D: 0.3469 Loss_G: 2.7581 D(x): 0.8046 D(G(z)): 0.0755 / 0.1036
[1/5][450/1583] Loss_D: 0.5065 Loss_G: 2.8547 D(x): 0.7491 D(G(z)): 0.1494 / 0.0879
[1/5][500/1583] Loss_D: 0.3959 Loss_G: 3.3236 D(x): 0.8292 D(G(z)): 0.1328 / 0.0554
[1/5][550/1583] Loss_D: 0.6679 Loss_G: 5.8782 D(x): 0.9178 D(G(z)): 0.3802 / 0.0075
[1/5][600/1583] Loss_D: 0.8844 Loss_G: 1.9449 D(x): 0.5367 D(G(z)): 0.0326 / 0.1984
[1/5][650/1583] Loss_D: 0.8474 Loss_G: 2.0978 D(x): 0.6395 D(G(z)): 0.1883 / 0.1803
[1/5][700/1583] Loss_D: 0.4682 Loss_G: 5.1056 D(x): 0.8963 D(G(z)): 0.2520 / 0.0137
[1/5][750/1583] Loss_D: 0.4315 Loss_G: 4.0099 D(x): 0.8957 D(G(z)): 0.2441 / 0.0304
[1/5][800/1583] Loss_D: 0.4492 Loss_G: 4.1587 D(x): 0.9090 D(G(z)): 0.2656 / 0.0231
[1/5][850/1583] Loss_D: 0.7694 Loss_G: 1.2065 D(x): 0.5726 D(G(z)): 0.0254 / 0.3785
[1/5][900/1583] Loss_D: 0.3543 Loss_G: 4.0476 D(x): 0.8919 D(G(z)): 0.1873 / 0.0284
[1/5][950/1583] Loss_D: 0.5111 Loss_G: 2.3574 D(x): 0.7082 D(G(z)): 0.0835 / 0.1288
[1/5][1000/1583] Loss_D: 0.5802 Loss_G: 5.4608 D(x): 0.9395 D(G(z)): 0.3649 / 0.0077
[1/5][1050/1583] Loss_D: 1.0051 Loss_G: 2.4068 D(x): 0.5352 D(G(z)): 0.0322 / 0.1486
[1/5][1100/1583] Loss_D: 0.3509 Loss_G: 3.6524 D(x): 0.9101 D(G(z)): 0.2070 / 0.0387
[1/5][1150/1583] Loss_D: 0.9412 Loss_G: 5.4059 D(x): 0.9597 D(G(z)): 0.5325 / 0.0080
[1/5][1200/1583] Loss_D: 0.5332 Loss_G: 3.1298 D(x): 0.7943 D(G(z)): 0.2138 / 0.0630
[1/5][1250/1583] Loss_D: 0.6025 Loss_G: 3.5758 D(x): 0.8679 D(G(z)): 0.3182 / 0.0428
[1/5][1300/1583] Loss_D: 0.7154 Loss_G: 2.1555 D(x): 0.5657 D(G(z)): 0.0379 / 0.1685
[1/5][1350/1583] Loss_D: 0.4168 Loss_G: 2.1878 D(x): 0.7452 D(G(z)): 0.0645 / 0.1534
[1/5][1400/1583] Loss_D: 0.8991 Loss_G: 5.3523 D(x): 0.9256 D(G(z)): 0.4967 / 0.0074
[1/5][1450/1583] Loss_D: 0.4778 Loss_G: 3.8499 D(x): 0.8844 D(G(z)): 0.2655 / 0.0350
[1/5][1500/1583] Loss_D: 0.5049 Loss_G: 2.5450 D(x): 0.7880 D(G(z)): 0.1906 / 0.1010
[1/5][1550/1583] Loss_D: 1.0468 Loss_G: 1.9007 D(x): 0.4378 D(G(z)): 0.0346 / 0.2260
[2/5][0/1583] Loss_D: 0.5008 Loss_G: 3.5294 D(x): 0.9006 D(G(z)): 0.2844 / 0.0466
[2/5][50/1583] Loss_D: 0.5024 Loss_G: 2.3252 D(x): 0.7413 D(G(z)): 0.1450 / 0.1267
[2/5][100/1583] Loss_D: 0.7520 Loss_G: 2.0230 D(x): 0.5753 D(G(z)): 0.0835 / 0.1797
[2/5][150/1583] Loss_D: 0.3734 Loss_G: 2.7221 D(x): 0.8502 D(G(z)): 0.1689 / 0.0889
[2/5][200/1583] Loss_D: 0.5891 Loss_G: 2.6314 D(x): 0.7453 D(G(z)): 0.2076 / 0.1032
[2/5][250/1583] Loss_D: 1.1471 Loss_G: 3.5814 D(x): 0.8959 D(G(z)): 0.5563 / 0.0545
[2/5][300/1583] Loss_D: 0.5756 Loss_G: 3.1905 D(x): 0.8738 D(G(z)): 0.3128 / 0.0605
[2/5][350/1583] Loss_D: 0.5971 Loss_G: 2.9928 D(x): 0.8177 D(G(z)): 0.2657 / 0.0739
[2/5][400/1583] Loss_D: 0.6856 Loss_G: 3.8514 D(x): 0.8880 D(G(z)): 0.3835 / 0.0298
[2/5][450/1583] Loss_D: 0.6088 Loss_G: 1.7919 D(x): 0.6660 D(G(z)): 0.1227 / 0.2189
[2/5][500/1583] Loss_D: 0.7147 Loss_G: 2.6453 D(x): 0.8321 D(G(z)): 0.3531 / 0.1007
[2/5][550/1583] Loss_D: 0.5759 Loss_G: 2.9074 D(x): 0.8269 D(G(z)): 0.2833 / 0.0738
[2/5][600/1583] Loss_D: 0.5678 Loss_G: 2.6149 D(x): 0.7928 D(G(z)): 0.2516 / 0.0956
[2/5][650/1583] Loss_D: 0.9501 Loss_G: 1.1814 D(x): 0.5916 D(G(z)): 0.2322 / 0.3815
[2/5][700/1583] Loss_D: 0.4551 Loss_G: 2.5074 D(x): 0.8331 D(G(z)): 0.2047 / 0.1129
[2/5][750/1583] Loss_D: 0.4560 Loss_G: 2.3947 D(x): 0.7525 D(G(z)): 0.1240 / 0.1147
[2/5][800/1583] Loss_D: 1.1853 Loss_G: 5.1657 D(x): 0.9202 D(G(z)): 0.6049 / 0.0091
[2/5][850/1583] Loss_D: 0.5514 Loss_G: 3.0085 D(x): 0.8497 D(G(z)): 0.2890 / 0.0685
[2/5][900/1583] Loss_D: 0.6882 Loss_G: 1.8971 D(x): 0.6970 D(G(z)): 0.2332 / 0.1909
[2/5][950/1583] Loss_D: 1.1220 Loss_G: 0.7904 D(x): 0.4095 D(G(z)): 0.0570 / 0.4975
[2/5][1000/1583] Loss_D: 1.3335 Loss_G: 0.3115 D(x): 0.3347 D(G(z)): 0.0262 / 0.7661
[2/5][1050/1583] Loss_D: 1.7281 Loss_G: 0.8212 D(x): 0.2437 D(G(z)): 0.0261 / 0.5179
[2/5][1100/1583] Loss_D: 0.9401 Loss_G: 3.7894 D(x): 0.9033 D(G(z)): 0.5104 / 0.0349
[2/5][1150/1583] Loss_D: 0.8078 Loss_G: 3.9862 D(x): 0.9178 D(G(z)): 0.4608 / 0.0286
[2/5][1200/1583] Loss_D: 0.5182 Loss_G: 3.1859 D(x): 0.8568 D(G(z)): 0.2787 / 0.0554
[2/5][1250/1583] Loss_D: 0.5092 Loss_G: 2.3530 D(x): 0.8015 D(G(z)): 0.2122 / 0.1188
[2/5][1300/1583] Loss_D: 1.2668 Loss_G: 0.5543 D(x): 0.3424 D(G(z)): 0.0165 / 0.6271
[2/5][1350/1583] Loss_D: 0.7197 Loss_G: 3.8595 D(x): 0.9043 D(G(z)): 0.4208 / 0.0299
[2/5][1400/1583] Loss_D: 0.5428 Loss_G: 2.6526 D(x): 0.8873 D(G(z)): 0.3056 / 0.0961
[2/5][1450/1583] Loss_D: 0.6610 Loss_G: 4.2385 D(x): 0.9272 D(G(z)): 0.3985 / 0.0211
[2/5][1500/1583] Loss_D: 0.8172 Loss_G: 3.2164 D(x): 0.8811 D(G(z)): 0.4422 / 0.0612
[2/5][1550/1583] Loss_D: 0.6449 Loss_G: 3.8452 D(x): 0.9130 D(G(z)): 0.3813 / 0.0325
[3/5][0/1583] Loss_D: 0.7677 Loss_G: 1.7745 D(x): 0.5928 D(G(z)): 0.1388 / 0.2182
[3/5][50/1583] Loss_D: 0.7981 Loss_G: 2.9624 D(x): 0.8315 D(G(z)): 0.4131 / 0.0735
[3/5][100/1583] Loss_D: 0.5679 Loss_G: 1.8958 D(x): 0.7173 D(G(z)): 0.1667 / 0.1914
[3/5][150/1583] Loss_D: 0.8576 Loss_G: 1.5904 D(x): 0.5391 D(G(z)): 0.1158 / 0.2699
[3/5][200/1583] Loss_D: 0.8644 Loss_G: 1.6487 D(x): 0.5868 D(G(z)): 0.1933 / 0.2319
[3/5][250/1583] Loss_D: 0.5331 Loss_G: 3.0401 D(x): 0.8831 D(G(z)): 0.3022 / 0.0608
[3/5][300/1583] Loss_D: 1.2449 Loss_G: 2.9489 D(x): 0.8759 D(G(z)): 0.5865 / 0.0828
[3/5][350/1583] Loss_D: 1.7188 Loss_G: 0.5466 D(x): 0.2664 D(G(z)): 0.0539 / 0.6320
[3/5][400/1583] Loss_D: 0.5794 Loss_G: 2.7556 D(x): 0.7984 D(G(z)): 0.2640 / 0.0787
[3/5][450/1583] Loss_D: 0.6916 Loss_G: 3.1434 D(x): 0.8813 D(G(z)): 0.3955 / 0.0578
[3/5][500/1583] Loss_D: 0.8415 Loss_G: 1.9770 D(x): 0.6981 D(G(z)): 0.3120 / 0.1639
[3/5][550/1583] Loss_D: 0.6394 Loss_G: 2.4790 D(x): 0.8093 D(G(z)): 0.2990 / 0.1082
[3/5][600/1583] Loss_D: 0.7545 Loss_G: 1.6259 D(x): 0.6042 D(G(z)): 0.1454 / 0.2401
[3/5][650/1583] Loss_D: 0.5494 Loss_G: 2.1957 D(x): 0.8292 D(G(z)): 0.2727 / 0.1414
[3/5][700/1583] Loss_D: 1.5095 Loss_G: 5.1368 D(x): 0.9269 D(G(z)): 0.6897 / 0.0095
[3/5][750/1583] Loss_D: 0.4714 Loss_G: 2.1401 D(x): 0.8137 D(G(z)): 0.2101 / 0.1501
[3/5][800/1583] Loss_D: 0.7118 Loss_G: 3.2356 D(x): 0.8190 D(G(z)): 0.3579 / 0.0540
[3/5][850/1583] Loss_D: 0.6392 Loss_G: 1.6740 D(x): 0.6650 D(G(z)): 0.1402 / 0.2391
[3/5][900/1583] Loss_D: 0.5303 Loss_G: 2.8854 D(x): 0.7900 D(G(z)): 0.2204 / 0.0740
[3/5][950/1583] Loss_D: 0.6333 Loss_G: 2.1030 D(x): 0.6946 D(G(z)): 0.1882 / 0.1572
[3/5][1000/1583] Loss_D: 0.8715 Loss_G: 1.6630 D(x): 0.5222 D(G(z)): 0.0890 / 0.2590
[3/5][1050/1583] Loss_D: 0.6139 Loss_G: 3.1772 D(x): 0.8609 D(G(z)): 0.3400 / 0.0558
[3/5][1100/1583] Loss_D: 0.6673 Loss_G: 3.4143 D(x): 0.9044 D(G(z)): 0.3910 / 0.0435
[3/5][1150/1583] Loss_D: 0.6554 Loss_G: 3.4282 D(x): 0.8429 D(G(z)): 0.3347 / 0.0484
[3/5][1200/1583] Loss_D: 0.6184 Loss_G: 1.7371 D(x): 0.6531 D(G(z)): 0.1177 / 0.2132
[3/5][1250/1583] Loss_D: 0.8293 Loss_G: 3.1246 D(x): 0.7821 D(G(z)): 0.3883 / 0.0594
[3/5][1300/1583] Loss_D: 0.5211 Loss_G: 2.0112 D(x): 0.7308 D(G(z)): 0.1503 / 0.1637
[3/5][1350/1583] Loss_D: 0.7389 Loss_G: 1.4238 D(x): 0.5854 D(G(z)): 0.1181 / 0.2935
[3/5][1400/1583] Loss_D: 0.6608 Loss_G: 3.1928 D(x): 0.7803 D(G(z)): 0.2922 / 0.0580
[3/5][1450/1583] Loss_D: 0.6381 Loss_G: 3.4123 D(x): 0.8340 D(G(z)): 0.3337 / 0.0450
[3/5][1500/1583] Loss_D: 0.7027 Loss_G: 3.1943 D(x): 0.9058 D(G(z)): 0.4113 / 0.0556
[3/5][1550/1583] Loss_D: 0.6849 Loss_G: 2.9714 D(x): 0.8258 D(G(z)): 0.3499 / 0.0704
[4/5][0/1583] Loss_D: 0.7685 Loss_G: 1.7204 D(x): 0.5788 D(G(z)): 0.1084 / 0.2252
[4/5][50/1583] Loss_D: 0.6194 Loss_G: 1.4702 D(x): 0.6214 D(G(z)): 0.0700 / 0.2812
[4/5][100/1583] Loss_D: 0.5243 Loss_G: 2.4332 D(x): 0.8206 D(G(z)): 0.2515 / 0.1099
[4/5][150/1583] Loss_D: 0.8506 Loss_G: 1.0129 D(x): 0.5094 D(G(z)): 0.0647 / 0.4126
[4/5][200/1583] Loss_D: 1.1715 Loss_G: 2.5120 D(x): 0.5642 D(G(z)): 0.3481 / 0.1214
[4/5][250/1583] Loss_D: 0.4317 Loss_G: 2.7731 D(x): 0.8405 D(G(z)): 0.2088 / 0.0791
[4/5][300/1583] Loss_D: 1.2310 Loss_G: 0.4177 D(x): 0.3812 D(G(z)): 0.0576 / 0.6799
[4/5][350/1583] Loss_D: 0.5565 Loss_G: 2.7405 D(x): 0.8525 D(G(z)): 0.3005 / 0.0810
[4/5][400/1583] Loss_D: 0.4918 Loss_G: 3.5705 D(x): 0.8863 D(G(z)): 0.2833 / 0.0371
[4/5][450/1583] Loss_D: 0.6403 Loss_G: 2.7691 D(x): 0.8543 D(G(z)): 0.3406 / 0.0812
[4/5][500/1583] Loss_D: 0.5944 Loss_G: 1.4696 D(x): 0.6849 D(G(z)): 0.1325 / 0.2682
[4/5][550/1583] Loss_D: 0.8678 Loss_G: 4.1990 D(x): 0.9529 D(G(z)): 0.5105 / 0.0202
[4/5][600/1583] Loss_D: 0.8326 Loss_G: 1.1841 D(x): 0.5175 D(G(z)): 0.0679 / 0.3628
[4/5][650/1583] Loss_D: 0.5198 Loss_G: 2.4393 D(x): 0.7668 D(G(z)): 0.1943 / 0.1148
[4/5][700/1583] Loss_D: 0.8029 Loss_G: 4.0836 D(x): 0.8791 D(G(z)): 0.4448 / 0.0229
[4/5][750/1583] Loss_D: 0.8636 Loss_G: 2.0386 D(x): 0.5234 D(G(z)): 0.0899 / 0.1846
[4/5][800/1583] Loss_D: 0.5041 Loss_G: 3.0354 D(x): 0.8302 D(G(z)): 0.2301 / 0.0609
[4/5][850/1583] Loss_D: 0.7514 Loss_G: 1.2513 D(x): 0.5578 D(G(z)): 0.0899 / 0.3480
[4/5][900/1583] Loss_D: 0.6650 Loss_G: 1.2806 D(x): 0.6675 D(G(z)): 0.1925 / 0.3201
[4/5][950/1583] Loss_D: 0.5754 Loss_G: 3.0898 D(x): 0.8730 D(G(z)): 0.3233 / 0.0597
[4/5][1000/1583] Loss_D: 0.9327 Loss_G: 0.7588 D(x): 0.4674 D(G(z)): 0.0434 / 0.5174
[4/5][1050/1583] Loss_D: 0.9255 Loss_G: 0.9513 D(x): 0.5029 D(G(z)): 0.1161 / 0.4196
[4/5][1100/1583] Loss_D: 0.6573 Loss_G: 3.4663 D(x): 0.8755 D(G(z)): 0.3674 / 0.0403
[4/5][1150/1583] Loss_D: 0.9803 Loss_G: 1.2451 D(x): 0.4602 D(G(z)): 0.0978 / 0.3432
[4/5][1200/1583] Loss_D: 0.5560 Loss_G: 2.5421 D(x): 0.7617 D(G(z)): 0.2097 / 0.1020
[4/5][1250/1583] Loss_D: 0.7573 Loss_G: 1.9034 D(x): 0.6477 D(G(z)): 0.2158 / 0.1890
[4/5][1300/1583] Loss_D: 0.4733 Loss_G: 2.7071 D(x): 0.8271 D(G(z)): 0.2169 / 0.0882
[4/5][1350/1583] Loss_D: 1.0812 Loss_G: 1.1500 D(x): 0.5225 D(G(z)): 0.2278 / 0.3626
[4/5][1400/1583] Loss_D: 1.5454 Loss_G: 5.2881 D(x): 0.9620 D(G(z)): 0.7085 / 0.0089
[4/5][1450/1583] Loss_D: 0.3576 Loss_G: 3.1023 D(x): 0.8687 D(G(z)): 0.1726 / 0.0584
[4/5][1500/1583] Loss_D: 0.5330 Loss_G: 1.9979 D(x): 0.7277 D(G(z)): 0.1597 / 0.1680
[4/5][1550/1583] Loss_D: 0.8927 Loss_G: 4.1379 D(x): 0.9345 D(G(z)): 0.5081 / 0.0224
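If only the printed log is available (for example from a saved console output), the statistics can be recovered with a regular expression. The sketch below is one possible parser for the line format shown above; the names `LOG_PATTERN` and `parse_log_line` and the dictionary keys are hypothetical.

```python
import re

# Matches lines such as:
# [0/5][0/1583] Loss_D: 2.0937 Loss_G: 5.2060 D(x): 0.5704 D(G(z)): 0.6680 / 0.0090
LOG_PATTERN = re.compile(
    r"\[(\d+)/(\d+)\]\[(\d+)/(\d+)\]\s+Loss_D:\s+([\d.]+)\s+Loss_G:\s+([\d.]+)"
    r"\s+D\(x\):\s+([\d.]+)\s+D\(G\(z\)\):\s+([\d.]+)\s+/\s+([\d.]+)"
)

def parse_log_line(line):
    """Return the statistics of one log line as a dict, or None if it does not match."""
    m = LOG_PATTERN.search(line)
    if m is None:
        return None
    epoch, num_epochs, step, num_steps = (int(v) for v in m.group(1, 2, 3, 4))
    loss_d, loss_g, d_x, d_g_z1, d_g_z2 = (float(v) for v in m.group(5, 6, 7, 8, 9))
    return {
        "epoch": epoch, "num_epochs": num_epochs,
        "step": step, "num_steps": num_steps,
        "loss_d": loss_d, "loss_g": loss_g,
        "d_x": d_x, "d_g_z1": d_g_z1, "d_g_z2": d_g_z2,
    }

# First line of the log above
row = parse_log_line(
    "[0/5][0/1583] Loss_D: 2.0937 Loss_G: 5.2060 D(x): 0.5704 D(G(z)): 0.6680 / 0.0090"
)
print(row)
```

The `\s+` separators accept both the tab characters emitted by the print statement and plain spaces.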
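5.6 Results
We look at the results from three angles:
- how the losses of G and D change during training
- the images G generates as training progresses
- a batch of generated images next to a batch of real images (64 each)
a. Loss curves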
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
b. Progression of the generated images
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
c. Real images vs. fake images
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))
# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
6. Next steps
- Train for longer (for example, with more epochs) to see how good the results get
- Modify this model to take a different dataset and possibly change the size of the images and the model architecture
- Check out some other cool GAN projects here: https://github.com/nashory/gans-awesome-applications
- Create GANs that generate music
7. References:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#
https://github.com/soumith/ganhacks#authors