A PyTorch Implementation of DCGAN

DCGAN

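1. What is a GAN

A GAN is a framework that lets a deep model learn the distribution of its training data, so that new data can be generated from that same distribution.

It consists of a discriminator and a generator. The generator produces "fake" data, and the discriminator judges whether a sample is real or fake. The two models improve together during training, so G becomes better at producing fakes while D becomes better at spotting them.

Training ideally ends in a state resembling a Nash equilibrium, in which the discriminator can no longer tell real from fake: its success rate is only 50%, no better than random guessing.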

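Let x be a sample of real data (for example an image or other multi-dimensional data), D() the discriminator, z a random variable, and G() the generator. D(x) outputs a scalar: the probability that x came from the real data distribution rather than from the generator. z is random noise (also called a latent space vector), and G(z) maps z into data space. The goal of G is to estimate the real distribution P_data; the distribution of the generated data G(z) is written P_G.

D(G(z)) is therefore a scalar: the probability that a generated image is real. D and G play a minimax game in which D maximizes and G minimizes. D tries to maximize the probability that it classifies real and fake images correctly:

log D(x)

G tries to minimize the probability that D identifies its output G(z) as fake:

log(1 - D(G(z)))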

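The overall objective (the GAN loss function) can be written as:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim P_{data}}[\log D(x)] + \mathbb{E}_{z \sim P_z}[\log(1 - D(G(z)))]$$

where P_z is the distribution from which the noise z is drawn.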
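As a quick numerical check, the value function V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))] can be estimated from lists of discriminator outputs. This is only a sketch in plain Python: the function name `gan_value` and the sample probabilities are illustrative assumptions, not part of the original code.

```python
import math

def gan_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: D's outputs on real samples, each in (0, 1)
    d_fake: D's outputs on generated samples, each in (0, 1)
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# A confident, correct discriminator pushes V towards 0 (its maximum).
v_strong_d = gan_value([0.99, 0.98], [0.01, 0.02])

# At the equilibrium D outputs 0.5 everywhere, so V = 2 * log(0.5) = -log(4).
v_equilibrium = gan_value([0.5, 0.5], [0.5, 0.5])

print(v_strong_d, v_equilibrium)
```

The equilibrium value of -log(4) is the point at which D's guesses are no better than chance.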

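2. What is DCGAN

DCGAN is a direct extension of GAN that uses convolutional layers in the discriminator and transposed-convolution (often called "deconvolution") layers in the generator.

The discriminator is built from strided convolution layers, batch normalization layers, and LeakyReLU activations. Its input is a 3x64x64 image and its output is a scalar probability that the input is real.

The generator is built from transposed-convolution layers, batch normalization layers, and ReLU activations. Its input is a latent vector z drawn from a standard normal distribution, and its output is a 3x64x64 image.

The authors of the paper "Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks" also give tips on how to set up the optimizers, how to compute the loss, and how to initialize the model weights.

The initial imports are as follows: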

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

3. Inputs

The input parameters are:

  • dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section
  • workers - the number of worker threads for loading the data with the DataLoader
  • batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128
  • image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed.
  • nc - number of color channels in the input images. For color images this is 3
  • nz - length of latent vector
  • ngf - relates to the depth of feature maps carried through the generator
  • ndf - sets the depth of feature maps propagated through the discriminator
  • num_epochs - number of training epochs to run. Training for longer will probably lead to better results but will also take much longer
  • lr - learning rate for training. As described in the DCGAN paper, this number should be 0.0002
  • beta1 - beta1 hyperparameter for Adam optimizers. As described in paper, this number should be 0.5
  • ngpu - number of GPUs available. If this is 0, code will run in CPU mode. If this number is greater than 0 it will run on that number of GPUs
# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

4. Data

The dataset is Celeb-A, from the Chinese University of Hong Kong (CUHK).

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
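# real_batch is a list with two elements.
# real_batch[0] is a tensor of shape [128, 3, 64, 64], the standard 4D batch layout: 128 images, 3 channels, height 64, width 64.
# real_batch[1] holds the labels for real_batch[0]: 128 values, all 0 here.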
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

# This call displays the figure
#plt.show() 
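`transforms.ToTensor()` scales pixel values to [0, 1], and `transforms.Normalize` then computes `(x - mean) / std` per channel. With mean and std both 0.5, pixels end up in [-1, 1], the same range as the generator's `Tanh` output. The arithmetic can be sketched in plain Python (the helper names `normalize` and `denormalize` are illustrative, not torchvision functions):

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-pixel arithmetic of Normalize: (x - mean) / std."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, useful before displaying generated images."""
    return y * std + mean

pixels = [0.0, 0.25, 0.5, 1.0]           # values after ToTensor, in [0, 1]
scaled = [normalize(p) for p in pixels]  # values the discriminator sees
restored = [denormalize(s) for s in scaled]

print(scaled)    # [-1.0, -0.5, 0.0, 1.0]
print(restored)  # [0.0, 0.25, 0.5, 1.0]
```

In the plotting code, `make_grid(..., normalize=True)` rescales the images to [0, 1] for display, so no manual inverse is needed there.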


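5. Implementation

5.1 Weight Initialization

The convolution weights are randomly initialized from a normal distribution with mean 0 and standard deviation 0.02. In the code below, the BatchNorm scale weights are drawn around a mean of 1 with the same standard deviation, and the BatchNorm biases are set to 0.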

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
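One way to check that the initialization does what is intended is to apply it to a small network and look at the resulting statistics. The sketch below repeats `weights_init` so that it runs on its own; the two-layer `toy` network is a hypothetical stand-in, not one of the models in this article.

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Same rule as above: N(0, 0.02) for conv layers, N(1, 0.02) for BatchNorm scale.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)
# Hypothetical toy network, only used to inspect the statistics after init.
toy = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1, bias=False),
    nn.BatchNorm2d(64),
)
toy.apply(weights_init)  # apply() visits every submodule recursively

conv_w = toy[0].weight
bn_w = toy[1].weight
print(conv_w.mean().item(), conv_w.std().item())           # close to 0 and 0.02
print(bn_w.mean().item(), toy[1].bias.abs().sum().item())  # close to 1, and 0
```

Because the match is done on the class name, `ConvTranspose2d` layers are covered by the `'Conv'` branch as well.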

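5.2 Generator

The generator G maps the latent vector z to data space.

  • nz=100: the length of the input vector z

  • nc=3: the number of output channels, three (RGB) for color images

  • ngf=64: the base number of feature-map channels in the generator (the layers use ngf*8, ngf*4, ngf*2, and ngf channels); it is not the spatial size of the feature maps

  • The transposed-convolution layer has the signature:

ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)

The parameters are: 1. input channels, 2. output channels, 3. kernel size, 4. stride, 5. padding applied to the input, 6. extra size added to one side of the output, 7. groups, 8. whether to add a bias, 9. dilation.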
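The spatial sizes noted in the generator's `state size` comments (4, 8, 16, 32, 64) follow from the output-size formula in the PyTorch documentation for `ConvTranspose2d`: H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. A small sketch that applies this formula to the five generator layers (the helper name `conv_transpose2d_out` is illustrative):

```python
def conv_transpose2d_out(size, kernel_size, stride=1, padding=0,
                         output_padding=0, dilation=1):
    """Output height/width of nn.ConvTranspose2d for a square input."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# (kernel_size, stride, padding) of the five generator layers
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector z enters as an nz x 1 x 1 tensor
sizes = []
for k, s, p in layers:
    size = conv_transpose2d_out(size, k, s, p)
    sizes.append(size)

print(sizes)  # [4, 8, 16, 32, 64]
```

With kernel 4, stride 2, and padding 1, each layer exactly doubles the spatial size, which is why changing `image_size` requires adding or removing layers.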

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            # in_channels=nz=100, out_channels=ngf*8=512, 4x4 kernel
            
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

Instantiate the generator and initialize its weights:

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)

out:

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

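5.3 Discriminator

The discriminator D is a binary classifier that decides whether an input image is real or fake. The image passes through a series of strided convolution (Conv2d), batch normalization (BatchNorm2d), and LeakyReLU layers, and a final Sigmoid activation outputs the probability.

This architecture can be extended with more layers if necessary. The DCGAN authors found experimentally that strided convolutions work better than pooling for downsampling, because they let the network learn its own pooling function. Batch normalization and LeakyReLU also promote healthy gradient flow, which matters even more when G and D are trained together.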
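The `state size` comments in the discriminator (32, 16, 8, 4, and a final 1x1 output) can be checked with the output-size formula from the PyTorch documentation for `Conv2d`: H_out = floor((H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1). A minimal sketch, with the illustrative helper name `conv2d_out`:

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of nn.Conv2d for a square input."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# (kernel_size, stride, padding) of the five discriminator layers
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input images are nc x 64 x 64
sizes = []
for k, s, p in layers:
    size = conv2d_out(size, k, s, p)
    sizes.append(size)

print(sizes)  # [32, 16, 8, 4, 1]
```

Each strided convolution halves the spatial size, and the last layer collapses the 4x4 map to one value per image, which the Sigmoid turns into a probability.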

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

Create D, apply the weight initialization, and print the model structure.

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)

out:

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

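5.4 Loss Function and Optimizers

We use PyTorch's built-in Binary Cross Entropy loss (BCELoss), defined as:

$$\ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \quad l_n = -\left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]$$

We define the real label as 1 and the fake label as 0, and set up two optimizers, one for D and one for G. Both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. To keep track of the generator's learning progress, we draw a fixed batch of latent vectors from a Gaussian distribution (fixed_noise). During training we periodically feed this fixed noise into G, so that the images it produces can be compared over time.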
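The choice of labels determines which half of the binary cross entropy is active: with target 1 only the log(x) term remains, and with target 0 only the log(1 - x) term remains. The following plain-Python sketch reproduces the mean-reduced formula by hand. The function `bce` and the sample outputs are illustrative assumptions; in the real code `nn.BCELoss` does this work.

```python
import math

def bce(outputs, targets):
    """Mean binary cross entropy, the default reduction of nn.BCELoss."""
    total = 0.0
    for x, y in zip(outputs, targets):
        total += -(y * math.log(x) + (1 - y) * math.log(1 - x))
    return total / len(outputs)

d_out = [0.9, 0.6, 0.8]  # hypothetical discriminator outputs

# Label 1 (real): only the log(x) term survives, so the loss is -mean(log D).
loss_real = bce(d_out, [1, 1, 1])
# Label 0 (fake): only the log(1 - x) term survives.
loss_fake = bce(d_out, [0, 0, 0])

print(loss_real, loss_fake)
```

This is why the training loop can reuse one criterion for both D and G simply by switching the label tensor between `real_label` and `fake_label`.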

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
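One way to read the choice of beta1 = 0.5 is through Adam's first-moment update, m_t = beta1 * m_{t-1} + (1 - beta1) * g_t, which is an exponential moving average of the gradients. The sketch below compares how much weight the two most recent gradients receive under beta1 = 0.5 and under PyTorch's default of 0.9. It ignores bias correction, and the helper `ema_weights` is purely illustrative.

```python
def ema_weights(beta1, steps):
    """Weight the first-moment average gives to the gradient k steps back.

    Unrolling m_t = beta1 * m_{t-1} + (1 - beta1) * g_t shows that the
    gradient from k steps ago is weighted by (1 - beta1) * beta1**k.
    """
    return [(1 - beta1) * beta1 ** k for k in range(steps)]

recent_share_dcgan = sum(ema_weights(0.5, 2))    # beta1 = 0.5, as used here
recent_share_default = sum(ema_weights(0.9, 2))  # PyTorch's default beta1

print(recent_share_dcgan, recent_share_default)  # 0.75 vs about 0.19
```

A lower beta1 makes the optimizer follow recent gradients more closely, which is plausibly useful when the loss surface keeps shifting because the opposing network is also being updated.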

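5.5 Training

Training a GAN is something of an art, and poorly chosen hyperparameters easily lead to mode collapse. We build separate mini-batches of real and fake images for D, and we change G's objective to maximizing log D(G(z)).

5.5.1 Discriminator D

The goal of training D is to maximize the probability that it classifies an input as real or fake correctly, by "ascending its stochastic gradient". In practice this means maximizing log(D(x)) + log(1 - D(G(z))).

This is done in two steps. First, take a batch of real samples from the training set, forward it through D, compute the loss log(D(x)), and compute the gradients with a backward pass.

Second, build a batch of fake samples with the current generator, forward it through D as well, compute the other half of the loss, log(1 - D(G(z))), and accumulate its gradients with another backward pass. With the gradients from the all-real and all-fake batches accumulated, we call one step of D's optimizer.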
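The two-step procedure relies on the fact that PyTorch accumulates gradients across `backward()` calls until they are cleared. The sketch below checks this on a tiny stand-in for D: two separate backward passes leave the same gradient as one backward pass on the summed loss. The single-layer `toy_d` and the random "real" and "fake" batches are hypothetical and only serve the demonstration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.BCELoss()

# Hypothetical stand-in for D: a single linear layer followed by a sigmoid.
toy_d = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
real = torch.randn(4, 8)   # stands in for a batch of real images
fake = torch.randn(4, 8)   # stands in for G(z).detach()
ones = torch.ones(4)
zeros = torch.zeros(4)

# Two separate backward passes: gradients accumulate in .grad.
toy_d.zero_grad()
criterion(toy_d(real).view(-1), ones).backward()
criterion(toy_d(fake).view(-1), zeros).backward()
grad_two_passes = toy_d[0].weight.grad.clone()

# One backward pass on the summed loss gives the same gradient.
toy_d.zero_grad()
loss = criterion(toy_d(real).view(-1), ones) + criterion(toy_d(fake).view(-1), zeros)
loss.backward()
grad_summed = toy_d[0].weight.grad.clone()

print(torch.allclose(grad_two_passes, grad_summed, atol=1e-6))
```

Splitting the update into two passes keeps the all-real and all-fake batches separate, so each half of the loss can be logged on its own.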

5.5.2 Generator G


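In the original GAN formulation, G is trained by minimizing log(1 - D(G(z))) in order to produce better fakes. However, this objective does not provide sufficient gradients, especially early in training. As a fix, we instead maximize log(D(G(z))). The statistics reported during training are:

  • Loss_D

The discriminator loss, computed as the sum of the losses on the all-real and all-fake batches: log(D(x)) + log(1 - D(G(z)))

  • Loss_G

The generator loss, computed as log(D(G(z)))

  • D(x)

The average output of the discriminator on the all-real batch. It should start close to 1 and in theory converge to 0.5 as G improves

  • D(G(z))

The average output of the discriminator on the all-fake batch. It should start close to 0 and in theory converge to 0.5 as G improves. The first number is measured before D is updated and the second after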

Training time depends on the number of passes over the whole dataset (epochs) and on the size of the dataset. The code is as follows:

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

out:

Starting Training Loop...
[0/5][0/1583]   Loss_D: 2.0937  Loss_G: 5.2060  D(x): 0.5704    D(G(z)): 0.6680 / 0.0090
[0/5][50/1583]  Loss_D: 0.1916  Loss_G: 9.5846  D(x): 0.9472    D(G(z)): 0.0364 / 0.0002
[0/5][100/1583] Loss_D: 4.0207  Loss_G: 21.2494 D(x): 0.2445    D(G(z)): 0.0000 / 0.0000
[0/5][150/1583] Loss_D: 0.5569  Loss_G: 3.1977  D(x): 0.7294    D(G(z)): 0.0974 / 0.0609
[0/5][200/1583] Loss_D: 0.2320  Loss_G: 3.3187  D(x): 0.9009    D(G(z)): 0.0805 / 0.0659
[0/5][250/1583] Loss_D: 0.7203  Loss_G: 5.9229  D(x): 0.8500    D(G(z)): 0.3485 / 0.0062
[0/5][300/1583] Loss_D: 0.6775  Loss_G: 4.0545  D(x): 0.8330    D(G(z)): 0.3379 / 0.0353
[0/5][350/1583] Loss_D: 0.7549  Loss_G: 5.9064  D(x): 0.9227    D(G(z)): 0.4109 / 0.0084
[0/5][400/1583] Loss_D: 1.0655  Loss_G: 2.5097  D(x): 0.4933    D(G(z)): 0.0269 / 0.1286
[0/5][450/1583] Loss_D: 0.6321  Loss_G: 2.7811  D(x): 0.6453    D(G(z)): 0.0610 / 0.1026
[0/5][500/1583] Loss_D: 0.5064  Loss_G: 4.1399  D(x): 0.9475    D(G(z)): 0.3009 / 0.0350
[0/5][550/1583] Loss_D: 0.3838  Loss_G: 4.0321  D(x): 0.8221    D(G(z)): 0.1218 / 0.0331
[0/5][600/1583] Loss_D: 0.5549  Loss_G: 4.6055  D(x): 0.8230    D(G(z)): 0.2049 / 0.0171
[0/5][650/1583] Loss_D: 0.2821  Loss_G: 6.8137  D(x): 0.8276    D(G(z)): 0.0164 / 0.0027
[0/5][700/1583] Loss_D: 0.6422  Loss_G: 5.0119  D(x): 0.8267    D(G(z)): 0.2827 / 0.0146
[0/5][750/1583] Loss_D: 0.4332  Loss_G: 4.3659  D(x): 0.9239    D(G(z)): 0.2307 / 0.0291
[0/5][800/1583] Loss_D: 0.5344  Loss_G: 3.4145  D(x): 0.7208    D(G(z)): 0.0891 / 0.0744
[0/5][850/1583] Loss_D: 0.8094  Loss_G: 2.9318  D(x): 0.5903    D(G(z)): 0.0602 / 0.0979
[0/5][900/1583] Loss_D: 0.1598  Loss_G: 6.4141  D(x): 0.9228    D(G(z)): 0.0630 / 0.0046
[0/5][950/1583] Loss_D: 0.5083  Loss_G: 5.5467  D(x): 0.9226    D(G(z)): 0.2916 / 0.0112
[0/5][1000/1583]        Loss_D: 0.6738  Loss_G: 3.9958  D(x): 0.7622    D(G(z)): 0.2480 / 0.0410
[0/5][1050/1583]        Loss_D: 0.2155  Loss_G: 3.8838  D(x): 0.9092    D(G(z)): 0.0819 / 0.0432
[0/5][1100/1583]        Loss_D: 1.1708  Loss_G: 1.9610  D(x): 0.4709    D(G(z)): 0.0064 / 0.2448
[0/5][1150/1583]        Loss_D: 0.7506  Loss_G: 6.9292  D(x): 0.8797    D(G(z)): 0.3728 / 0.0019
[0/5][1200/1583]        Loss_D: 0.2133  Loss_G: 5.5082  D(x): 0.9436    D(G(z)): 0.1272 / 0.0102
[0/5][1250/1583]        Loss_D: 0.5156  Loss_G: 3.8660  D(x): 0.8073    D(G(z)): 0.1993 / 0.0357
[0/5][1300/1583]        Loss_D: 0.4848  Loss_G: 5.0770  D(x): 0.9170    D(G(z)): 0.2847 / 0.0109
[0/5][1350/1583]        Loss_D: 0.6596  Loss_G: 4.7626  D(x): 0.8414    D(G(z)): 0.3232 / 0.0145
[0/5][1400/1583]        Loss_D: 0.2799  Loss_G: 5.1604  D(x): 0.9154    D(G(z)): 0.1494 / 0.0156
[0/5][1450/1583]        Loss_D: 0.4756  Loss_G: 2.9344  D(x): 0.8164    D(G(z)): 0.1785 / 0.0955
[0/5][1500/1583]        Loss_D: 0.3904  Loss_G: 2.3755  D(x): 0.7652    D(G(z)): 0.0587 / 0.1328
[0/5][1550/1583]        Loss_D: 1.2817  Loss_G: 1.2689  D(x): 0.3769    D(G(z)): 0.0221 / 0.3693
[1/5][0/1583]   Loss_D: 0.5365  Loss_G: 3.0092  D(x): 0.7437    D(G(z)): 0.1574 / 0.0836
[1/5][50/1583]  Loss_D: 0.4959  Loss_G: 5.4086  D(x): 0.9422    D(G(z)): 0.2960 / 0.0086
[1/5][100/1583] Loss_D: 0.2685  Loss_G: 3.6553  D(x): 0.8455    D(G(z)): 0.0640 / 0.0457
[1/5][150/1583] Loss_D: 0.6243  Loss_G: 4.6128  D(x): 0.8467    D(G(z)): 0.2878 / 0.0203
[1/5][200/1583] Loss_D: 0.4369  Loss_G: 2.8268  D(x): 0.7591    D(G(z)): 0.0871 / 0.0871
[1/5][250/1583] Loss_D: 0.4244  Loss_G: 3.7669  D(x): 0.8641    D(G(z)): 0.1952 / 0.0369
[1/5][300/1583] Loss_D: 0.7487  Loss_G: 2.5417  D(x): 0.6388    D(G(z)): 0.0948 / 0.1263
[1/5][350/1583] Loss_D: 0.5359  Loss_G: 2.9435  D(x): 0.6996    D(G(z)): 0.0836 / 0.0864
[1/5][400/1583] Loss_D: 0.3469  Loss_G: 2.7581  D(x): 0.8046    D(G(z)): 0.0755 / 0.1036
[1/5][450/1583] Loss_D: 0.5065  Loss_G: 2.8547  D(x): 0.7491    D(G(z)): 0.1494 / 0.0879
[1/5][500/1583] Loss_D: 0.3959  Loss_G: 3.3236  D(x): 0.8292    D(G(z)): 0.1328 / 0.0554
[1/5][550/1583] Loss_D: 0.6679  Loss_G: 5.8782  D(x): 0.9178    D(G(z)): 0.3802 / 0.0075
[1/5][600/1583] Loss_D: 0.8844  Loss_G: 1.9449  D(x): 0.5367    D(G(z)): 0.0326 / 0.1984
[1/5][650/1583] Loss_D: 0.8474  Loss_G: 2.0978  D(x): 0.6395    D(G(z)): 0.1883 / 0.1803
[1/5][700/1583] Loss_D: 0.4682  Loss_G: 5.1056  D(x): 0.8963    D(G(z)): 0.2520 / 0.0137
[1/5][750/1583] Loss_D: 0.4315  Loss_G: 4.0099  D(x): 0.8957    D(G(z)): 0.2441 / 0.0304
[1/5][800/1583] Loss_D: 0.4492  Loss_G: 4.1587  D(x): 0.9090    D(G(z)): 0.2656 / 0.0231
[1/5][850/1583] Loss_D: 0.7694  Loss_G: 1.2065  D(x): 0.5726    D(G(z)): 0.0254 / 0.3785
[1/5][900/1583] Loss_D: 0.3543  Loss_G: 4.0476  D(x): 0.8919    D(G(z)): 0.1873 / 0.0284
[1/5][950/1583] Loss_D: 0.5111  Loss_G: 2.3574  D(x): 0.7082    D(G(z)): 0.0835 / 0.1288
[1/5][1000/1583]        Loss_D: 0.5802  Loss_G: 5.4608  D(x): 0.9395    D(G(z)): 0.3649 / 0.0077
[1/5][1050/1583]        Loss_D: 1.0051  Loss_G: 2.4068  D(x): 0.5352    D(G(z)): 0.0322 / 0.1486
[1/5][1100/1583]        Loss_D: 0.3509  Loss_G: 3.6524  D(x): 0.9101    D(G(z)): 0.2070 / 0.0387
[1/5][1150/1583]        Loss_D: 0.9412  Loss_G: 5.4059  D(x): 0.9597    D(G(z)): 0.5325 / 0.0080
[1/5][1200/1583]        Loss_D: 0.5332  Loss_G: 3.1298  D(x): 0.7943    D(G(z)): 0.2138 / 0.0630
[1/5][1250/1583]        Loss_D: 0.6025  Loss_G: 3.5758  D(x): 0.8679    D(G(z)): 0.3182 / 0.0428
[1/5][1300/1583]        Loss_D: 0.7154  Loss_G: 2.1555  D(x): 0.5657    D(G(z)): 0.0379 / 0.1685
[1/5][1350/1583]        Loss_D: 0.4168  Loss_G: 2.1878  D(x): 0.7452    D(G(z)): 0.0645 / 0.1534
[1/5][1400/1583]        Loss_D: 0.8991  Loss_G: 5.3523  D(x): 0.9256    D(G(z)): 0.4967 / 0.0074
[1/5][1450/1583]        Loss_D: 0.4778  Loss_G: 3.8499  D(x): 0.8844    D(G(z)): 0.2655 / 0.0350
[1/5][1500/1583]        Loss_D: 0.5049  Loss_G: 2.5450  D(x): 0.7880    D(G(z)): 0.1906 / 0.1010
[1/5][1550/1583]        Loss_D: 1.0468  Loss_G: 1.9007  D(x): 0.4378    D(G(z)): 0.0346 / 0.2260
[2/5][0/1583]   Loss_D: 0.5008  Loss_G: 3.5294  D(x): 0.9006    D(G(z)): 0.2844 / 0.0466
[2/5][50/1583]  Loss_D: 0.5024  Loss_G: 2.3252  D(x): 0.7413    D(G(z)): 0.1450 / 0.1267
[2/5][100/1583] Loss_D: 0.7520  Loss_G: 2.0230  D(x): 0.5753    D(G(z)): 0.0835 / 0.1797
[2/5][150/1583] Loss_D: 0.3734  Loss_G: 2.7221  D(x): 0.8502    D(G(z)): 0.1689 / 0.0889
[2/5][200/1583] Loss_D: 0.5891  Loss_G: 2.6314  D(x): 0.7453    D(G(z)): 0.2076 / 0.1032
[2/5][250/1583] Loss_D: 1.1471  Loss_G: 3.5814  D(x): 0.8959    D(G(z)): 0.5563 / 0.0545
[2/5][300/1583] Loss_D: 0.5756  Loss_G: 3.1905  D(x): 0.8738    D(G(z)): 0.3128 / 0.0605
[2/5][350/1583] Loss_D: 0.5971  Loss_G: 2.9928  D(x): 0.8177    D(G(z)): 0.2657 / 0.0739
[2/5][400/1583] Loss_D: 0.6856  Loss_G: 3.8514  D(x): 0.8880    D(G(z)): 0.3835 / 0.0298
[2/5][450/1583] Loss_D: 0.6088  Loss_G: 1.7919  D(x): 0.6660    D(G(z)): 0.1227 / 0.2189
[2/5][500/1583] Loss_D: 0.7147  Loss_G: 2.6453  D(x): 0.8321    D(G(z)): 0.3531 / 0.1007
[2/5][550/1583] Loss_D: 0.5759  Loss_G: 2.9074  D(x): 0.8269    D(G(z)): 0.2833 / 0.0738
[2/5][600/1583] Loss_D: 0.5678  Loss_G: 2.6149  D(x): 0.7928    D(G(z)): 0.2516 / 0.0956
[2/5][650/1583] Loss_D: 0.9501  Loss_G: 1.1814  D(x): 0.5916    D(G(z)): 0.2322 / 0.3815
[2/5][700/1583] Loss_D: 0.4551  Loss_G: 2.5074  D(x): 0.8331    D(G(z)): 0.2047 / 0.1129
[2/5][750/1583] Loss_D: 0.4560  Loss_G: 2.3947  D(x): 0.7525    D(G(z)): 0.1240 / 0.1147
[2/5][800/1583] Loss_D: 1.1853  Loss_G: 5.1657  D(x): 0.9202    D(G(z)): 0.6049 / 0.0091
[2/5][850/1583] Loss_D: 0.5514  Loss_G: 3.0085  D(x): 0.8497    D(G(z)): 0.2890 / 0.0685
[2/5][900/1583] Loss_D: 0.6882  Loss_G: 1.8971  D(x): 0.6970    D(G(z)): 0.2332 / 0.1909
[2/5][950/1583] Loss_D: 1.1220  Loss_G: 0.7904  D(x): 0.4095    D(G(z)): 0.0570 / 0.4975
[2/5][1000/1583]        Loss_D: 1.3335  Loss_G: 0.3115  D(x): 0.3347    D(G(z)): 0.0262 / 0.7661
[2/5][1050/1583]        Loss_D: 1.7281  Loss_G: 0.8212  D(x): 0.2437    D(G(z)): 0.0261 / 0.5179
[2/5][1100/1583]        Loss_D: 0.9401  Loss_G: 3.7894  D(x): 0.9033    D(G(z)): 0.5104 / 0.0349
[2/5][1150/1583]        Loss_D: 0.8078  Loss_G: 3.9862  D(x): 0.9178    D(G(z)): 0.4608 / 0.0286
[2/5][1200/1583]        Loss_D: 0.5182  Loss_G: 3.1859  D(x): 0.8568    D(G(z)): 0.2787 / 0.0554
[2/5][1250/1583]        Loss_D: 0.5092  Loss_G: 2.3530  D(x): 0.8015    D(G(z)): 0.2122 / 0.1188
[2/5][1300/1583]        Loss_D: 1.2668  Loss_G: 0.5543  D(x): 0.3424    D(G(z)): 0.0165 / 0.6271
[2/5][1350/1583]        Loss_D: 0.7197  Loss_G: 3.8595  D(x): 0.9043    D(G(z)): 0.4208 / 0.0299
[2/5][1400/1583]        Loss_D: 0.5428  Loss_G: 2.6526  D(x): 0.8873    D(G(z)): 0.3056 / 0.0961
[2/5][1450/1583]        Loss_D: 0.6610  Loss_G: 4.2385  D(x): 0.9272    D(G(z)): 0.3985 / 0.0211
[2/5][1500/1583]        Loss_D: 0.8172  Loss_G: 3.2164  D(x): 0.8811    D(G(z)): 0.4422 / 0.0612
[2/5][1550/1583]        Loss_D: 0.6449  Loss_G: 3.8452  D(x): 0.9130    D(G(z)): 0.3813 / 0.0325
[3/5][0/1583]   Loss_D: 0.7677  Loss_G: 1.7745  D(x): 0.5928    D(G(z)): 0.1388 / 0.2182
[3/5][50/1583]  Loss_D: 0.7981  Loss_G: 2.9624  D(x): 0.8315    D(G(z)): 0.4131 / 0.0735
[3/5][100/1583] Loss_D: 0.5679  Loss_G: 1.8958  D(x): 0.7173    D(G(z)): 0.1667 / 0.1914
[3/5][150/1583] Loss_D: 0.8576  Loss_G: 1.5904  D(x): 0.5391    D(G(z)): 0.1158 / 0.2699
[3/5][200/1583] Loss_D: 0.8644  Loss_G: 1.6487  D(x): 0.5868    D(G(z)): 0.1933 / 0.2319
[3/5][250/1583] Loss_D: 0.5331  Loss_G: 3.0401  D(x): 0.8831    D(G(z)): 0.3022 / 0.0608
[3/5][300/1583] Loss_D: 1.2449  Loss_G: 2.9489  D(x): 0.8759    D(G(z)): 0.5865 / 0.0828
[3/5][350/1583] Loss_D: 1.7188  Loss_G: 0.5466  D(x): 0.2664    D(G(z)): 0.0539 / 0.6320
[3/5][400/1583] Loss_D: 0.5794  Loss_G: 2.7556  D(x): 0.7984    D(G(z)): 0.2640 / 0.0787
[3/5][450/1583] Loss_D: 0.6916  Loss_G: 3.1434  D(x): 0.8813    D(G(z)): 0.3955 / 0.0578
[3/5][500/1583] Loss_D: 0.8415  Loss_G: 1.9770  D(x): 0.6981    D(G(z)): 0.3120 / 0.1639
[3/5][550/1583] Loss_D: 0.6394  Loss_G: 2.4790  D(x): 0.8093    D(G(z)): 0.2990 / 0.1082
[3/5][600/1583] Loss_D: 0.7545  Loss_G: 1.6259  D(x): 0.6042    D(G(z)): 0.1454 / 0.2401
[3/5][650/1583] Loss_D: 0.5494  Loss_G: 2.1957  D(x): 0.8292    D(G(z)): 0.2727 / 0.1414
[3/5][700/1583] Loss_D: 1.5095  Loss_G: 5.1368  D(x): 0.9269    D(G(z)): 0.6897 / 0.0095
[3/5][750/1583] Loss_D: 0.4714  Loss_G: 2.1401  D(x): 0.8137    D(G(z)): 0.2101 / 0.1501
[3/5][800/1583] Loss_D: 0.7118  Loss_G: 3.2356  D(x): 0.8190    D(G(z)): 0.3579 / 0.0540
[3/5][850/1583] Loss_D: 0.6392  Loss_G: 1.6740  D(x): 0.6650    D(G(z)): 0.1402 / 0.2391
[3/5][900/1583] Loss_D: 0.5303  Loss_G: 2.8854  D(x): 0.7900    D(G(z)): 0.2204 / 0.0740
[3/5][950/1583] Loss_D: 0.6333  Loss_G: 2.1030  D(x): 0.6946    D(G(z)): 0.1882 / 0.1572
[3/5][1000/1583]        Loss_D: 0.8715  Loss_G: 1.6630  D(x): 0.5222    D(G(z)): 0.0890 / 0.2590
[3/5][1050/1583]        Loss_D: 0.6139  Loss_G: 3.1772  D(x): 0.8609    D(G(z)): 0.3400 / 0.0558
[3/5][1100/1583]        Loss_D: 0.6673  Loss_G: 3.4143  D(x): 0.9044    D(G(z)): 0.3910 / 0.0435
[3/5][1150/1583]        Loss_D: 0.6554  Loss_G: 3.4282  D(x): 0.8429    D(G(z)): 0.3347 / 0.0484
[3/5][1200/1583]        Loss_D: 0.6184  Loss_G: 1.7371  D(x): 0.6531    D(G(z)): 0.1177 / 0.2132
[3/5][1250/1583]        Loss_D: 0.8293  Loss_G: 3.1246  D(x): 0.7821    D(G(z)): 0.3883 / 0.0594
[3/5][1300/1583]        Loss_D: 0.5211  Loss_G: 2.0112  D(x): 0.7308    D(G(z)): 0.1503 / 0.1637
[3/5][1350/1583]        Loss_D: 0.7389  Loss_G: 1.4238  D(x): 0.5854    D(G(z)): 0.1181 / 0.2935
[3/5][1400/1583]        Loss_D: 0.6608  Loss_G: 3.1928  D(x): 0.7803    D(G(z)): 0.2922 / 0.0580
[3/5][1450/1583]        Loss_D: 0.6381  Loss_G: 3.4123  D(x): 0.8340    D(G(z)): 0.3337 / 0.0450
[3/5][1500/1583]        Loss_D: 0.7027  Loss_G: 3.1943  D(x): 0.9058    D(G(z)): 0.4113 / 0.0556
[3/5][1550/1583]        Loss_D: 0.6849  Loss_G: 2.9714  D(x): 0.8258    D(G(z)): 0.3499 / 0.0704
[4/5][0/1583]   Loss_D: 0.7685  Loss_G: 1.7204  D(x): 0.5788    D(G(z)): 0.1084 / 0.2252
[4/5][50/1583]  Loss_D: 0.6194  Loss_G: 1.4702  D(x): 0.6214    D(G(z)): 0.0700 / 0.2812
[4/5][100/1583] Loss_D: 0.5243  Loss_G: 2.4332  D(x): 0.8206    D(G(z)): 0.2515 / 0.1099
[4/5][150/1583] Loss_D: 0.8506  Loss_G: 1.0129  D(x): 0.5094    D(G(z)): 0.0647 / 0.4126
[4/5][200/1583] Loss_D: 1.1715  Loss_G: 2.5120  D(x): 0.5642    D(G(z)): 0.3481 / 0.1214
[4/5][250/1583] Loss_D: 0.4317  Loss_G: 2.7731  D(x): 0.8405    D(G(z)): 0.2088 / 0.0791
[4/5][300/1583] Loss_D: 1.2310  Loss_G: 0.4177  D(x): 0.3812    D(G(z)): 0.0576 / 0.6799
[4/5][350/1583] Loss_D: 0.5565  Loss_G: 2.7405  D(x): 0.8525    D(G(z)): 0.3005 / 0.0810
[4/5][400/1583] Loss_D: 0.4918  Loss_G: 3.5705  D(x): 0.8863    D(G(z)): 0.2833 / 0.0371
[4/5][450/1583] Loss_D: 0.6403  Loss_G: 2.7691  D(x): 0.8543    D(G(z)): 0.3406 / 0.0812
[4/5][500/1583] Loss_D: 0.5944  Loss_G: 1.4696  D(x): 0.6849    D(G(z)): 0.1325 / 0.2682
[4/5][550/1583] Loss_D: 0.8678  Loss_G: 4.1990  D(x): 0.9529    D(G(z)): 0.5105 / 0.0202
[4/5][600/1583] Loss_D: 0.8326  Loss_G: 1.1841  D(x): 0.5175    D(G(z)): 0.0679 / 0.3628
[4/5][650/1583] Loss_D: 0.5198  Loss_G: 2.4393  D(x): 0.7668    D(G(z)): 0.1943 / 0.1148
[4/5][700/1583] Loss_D: 0.8029  Loss_G: 4.0836  D(x): 0.8791    D(G(z)): 0.4448 / 0.0229
[4/5][750/1583] Loss_D: 0.8636  Loss_G: 2.0386  D(x): 0.5234    D(G(z)): 0.0899 / 0.1846
[4/5][800/1583] Loss_D: 0.5041  Loss_G: 3.0354  D(x): 0.8302    D(G(z)): 0.2301 / 0.0609
[4/5][850/1583] Loss_D: 0.7514  Loss_G: 1.2513  D(x): 0.5578    D(G(z)): 0.0899 / 0.3480
[4/5][900/1583] Loss_D: 0.6650  Loss_G: 1.2806  D(x): 0.6675    D(G(z)): 0.1925 / 0.3201
[4/5][950/1583] Loss_D: 0.5754  Loss_G: 3.0898  D(x): 0.8730    D(G(z)): 0.3233 / 0.0597
[4/5][1000/1583]        Loss_D: 0.9327  Loss_G: 0.7588  D(x): 0.4674    D(G(z)): 0.0434 / 0.5174
[4/5][1050/1583]        Loss_D: 0.9255  Loss_G: 0.9513  D(x): 0.5029    D(G(z)): 0.1161 / 0.4196
[4/5][1100/1583]        Loss_D: 0.6573  Loss_G: 3.4663  D(x): 0.8755    D(G(z)): 0.3674 / 0.0403
[4/5][1150/1583]        Loss_D: 0.9803  Loss_G: 1.2451  D(x): 0.4602    D(G(z)): 0.0978 / 0.3432
[4/5][1200/1583]        Loss_D: 0.5560  Loss_G: 2.5421  D(x): 0.7617    D(G(z)): 0.2097 / 0.1020
[4/5][1250/1583]        Loss_D: 0.7573  Loss_G: 1.9034  D(x): 0.6477    D(G(z)): 0.2158 / 0.1890
[4/5][1300/1583]        Loss_D: 0.4733  Loss_G: 2.7071  D(x): 0.8271    D(G(z)): 0.2169 / 0.0882
[4/5][1350/1583]        Loss_D: 1.0812  Loss_G: 1.1500  D(x): 0.5225    D(G(z)): 0.2278 / 0.3626
[4/5][1400/1583]        Loss_D: 1.5454  Loss_G: 5.2881  D(x): 0.9620    D(G(z)): 0.7085 / 0.0089
[4/5][1450/1583]        Loss_D: 0.3576  Loss_G: 3.1023  D(x): 0.8687    D(G(z)): 0.1726 / 0.0584
[4/5][1500/1583]        Loss_D: 0.5330  Loss_G: 1.9979  D(x): 0.7277    D(G(z)): 0.1597 / 0.1680
[4/5][1550/1583]        Loss_D: 0.8927  Loss_G: 4.1379  D(x): 0.9345    D(G(z)): 0.5081 / 0.0224
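The figure of 1583 batches per epoch in the log is consistent with the batch size of 128 if the dataset holds 202,599 images, which is the size of the aligned CelebA image set. That image count is an assumption here, since the article does not state it. A short check of the arithmetic:

```python
import math

num_images = 202_599   # assumption: size of the aligned CelebA image set
batch_size = 128       # as configured in the inputs section

# DataLoader keeps the final partial batch by default (drop_last=False),
# so the number of batches is the ceiling of the division.
batches_per_epoch = math.ceil(num_images / batch_size)
last_batch_size = num_images - (batches_per_epoch - 1) * batch_size

print(batches_per_epoch, last_batch_size)  # 1583 103
```

Because the last batch is smaller than 128, the training loop reads the batch size from `real_cpu.size(0)` instead of using `batch_size` directly.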

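5.6 Results

We look at the results from three angles:

  • how the losses of G and D change during training
  • the images G generates as training progresses
  • a batch of generated images side by side with a batch of real images (64 each)

a. Loss during training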

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

b. Progression of the generated images

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

c. Real images vs. fake images

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()


6. Next Steps

  • Train for longer (for example, with more epochs) to see how good the results get

  • Modify this model to take a different dataset and possibly change the size of the images and the model architecture

  • Check out some other cool GAN projects: https://github.com/nashory/gans-awesome-applications

  • Create GANs that generate music

7. References

https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#

https://github.com/soumith/ganhacks#authors
