【轉載】DCGAN及其TensorFlow源碼

上一節我們提到G和D由多層感知機定義。深度學習中對圖像處理應用最好的模型是CNN,那麼如何把CNN與GAN結合?DCGAN是這方面最好的嘗試之一。源碼:https://github.com/Newmu/dcgan_code 。DCGAN論文作者用theano實現的,他還放上了其他人實現的版本,本文主要討論tensorflow版本。
TensorFlow版本的源碼:https://github.com/carpedm20/DCGAN-tensorflow

DCGAN把上述的G和D換成了兩個卷積神經網絡(CNN)。但不是直接換就可以了,DCGAN對卷積神經網絡的結構做了一些改變,以提高樣本的質量和收斂的速度,這些改變有:

  • 取消所有pooling層。G網絡中使用轉置卷積(transposed convolutional layer)進行上採樣,D網絡中用加入strided的卷積代替pooling。
  • 在D和G中均使用batch normalization
  • 去掉FC層,使網絡變爲全卷積網絡
  • G網絡中使用ReLU作爲激活函數,最後一層使用tanh
  • D網絡中使用LeakyReLU作爲激活函數

這些改變在代碼中都可以看到。DCGAN論文中提到對CNN結構有三點重要的改變:

  1. Allconvolutional net (Springenberg et al., 2014) 全卷積網絡
    判別模型D:使用帶步長的卷積(strided convolutions)取代了的空間池化(spatial pooling),容許網絡學習自己的空間下采樣(spatial downsampling)。
    Ÿ 生成模型G:使用微步幅卷積(fractional strided),容許它學習自己的空間上採樣(spatial upsampling)
  2. 在卷積特徵之上消除全連接層。
    Ÿ (Mordvintsev et al.)提出的全局平均池化有助於模型的穩定性,但損害收斂速度。
    GAN的第一層輸入:服從均勻分佈的噪聲向量Z,因爲只有矩陣乘法,因此可以被叫做全連接層,但結果會被reshape成4維張量,作爲卷積棧的開始。
    對於D,最後的卷積層被flatten(把矩陣變成向量),然後使用sigmoid函數處理輸出。
    生成模型:輸出層用Tanh函數,其它層用ReLU激活函數。
    判別模型:所有層使用LeakyReLU
  3. Batch Normalization 批標準化。
    解決因糟糕的初始化引起的訓練問題,使得梯度能傳播更深層次。穩定學習,通過歸一化輸入的單元,使它們平均值爲0,具有單位方差。
    批標準化證明了生成模型初始化的重要性,避免生成模型崩潰:生成的所有樣本都在一個點上(樣本相同),這是訓練GANs經常遇到的失敗現象。
    generator:100維的均勻分佈Z投影到小的空間範圍卷積表示,產生許多特徵圖。一系列四步卷積將這個表示轉換爲64x64像素的圖像。不用到完全連接或者池化層。

配置

Python
TensorFlow
SciPy
pillow
(可選)moviepy (https://github.com/Zulko/moviepy):用於可視化
(可選)Align&Cropped Images.zip (http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html):人臉數據集

main.py

入口程序,事先定義所需參數的值。
執行程序:
訓練一個模型:
$ python main.py --dataset mnist --is_trainTrue
$ python main.py --dataset celebA --is_trainTrue --is_crop True
測試一個已存在模型:
$ python main.py --dataset mnist
$ python main.py --dataset celebA --is_crop True
你也可以使用自己的dataset:
$ mkdir data/DATASET_NAME
添加圖片到data/DATASET_NAME …
$ python main.py --dataset DATASET_NAME--is_train True
$ python main.py --dataset DATASET_NAME
訓練出多張以假亂真的圖片

源碼分析

flags配置network的參數,在命令行中可以修改,比如
$python main.py --image_size 96 --output_size 48 --dataset anime --is_crop True--is_train True --epoch 300
該套代碼參數主要以mnist數據集爲模板,如果要訓練別的數據集,可以適當修改一些參數。mnist數據集可以通過download.py下載。
首先初始化model.py中的DCGAN,然後看是否需要訓練(is_train)。

FLAGS參數

epochepoch:可視化爲True,不可視化爲False,默認爲False

model.py

初始化參數

model.py定義了DCGAN類,包括9個函數

__init__()

參數初始化,已講過的input_height, input_width, crop, batch_size, output_height, output_width, dataset_name, input_fname_pattern, checkpoint_dir, sample_dir就不再說了
sample_numsample_num:顏色通道,灰度圖像設爲1,彩色圖像設爲3,默認爲3
其中self.d_bn1, self.d_bn2, g_bn0, g_bn1, g_bn2是batch標準化,見ops.py的batch_norm(object)。
如果是mnist數據集,d_bn3, g_bn3都要batch_norm。
self.data讀取數據集。
然後建立模型(build_model)

build_model()

inputs的形狀爲[batch_size, input_height, input_width, c_dim]。
如果crop=True,inputs的形狀爲[batch_size, output_height, output_width, c_dim]。
輸入分爲樣本輸入inputs和抽樣輸入sample_inputs。
噪聲z的形狀爲[None, z_dim],第一個None是batch的大小。
然後取數據:
self.G = self.generator(self.z)#返回[batch_size, output_height, output_width, c_dim]形狀的張量,也就是batch_size張圖
self.D, self.D_logits = self.discriminator(inputs)#返回的D爲是否是真樣本的sigmoid概率,D_logits是未經sigmoid處理
self.sampler = self.sampler(self.z)#相當於測試,經過G網絡模型,取樣,代碼和G很像,沒有G訓練的過程。
self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)
#D是真實數據,D_是假數據
用交叉熵計算損失,共有:d_loss_real、d_loss_fake、g_loss
self.d_loss_real = tf.reduce_mean(
sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))

self.d_loss_fake = tf.reduce_mean(
sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))

self.g_loss = tf.reduce_mean(
sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))

tf.ones_like:新建一個與給定tensor大小一致的tensor,其全部元素爲1
d_loss_real是真樣本輸入的損失,要讓D_logits接近於1,也就是D識別出真樣本爲真的
d_loss_fake是假樣本輸入的損失,要讓D_logits_接近於0,D識別出假樣本爲假
d_loss = d_loss_real + d_loss_fake是D的目標,要最小化這個損失
g_loss:要讓D識別假樣本爲真樣本,G的目標是降低這個損失,D是提高這個損失

summary這幾步是關於可視化,就不管了

train()

通過Adam優化器最小化d_loss和g_loss。
sample_z爲從-1到1均勻分佈的數,大小爲[sample_num, z_dim]
從路徑中讀取原始樣本sample,大小爲[sample_num, output_height, output_width, c_dim]
接下來進行epoch個訓練:
將data總數分爲batch_idxs次訓練,每次訓練batch_size個樣本。產生的樣本爲batch_images。
batch_z爲訓練的噪聲,大小爲[batch_num, z_dim]
d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
.minimize(self.d_loss, var_list=self.d_vars)

g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
.minimize(self.g_loss, var_list=self.g_vars)

首先輸入噪聲z和batch_images,通過優化d_optim更新D網絡。
然後輸入噪聲z,優化g_optim來更新G網絡。G網絡更新兩次,以免d_loss爲0。這點不同於paper。
這樣的訓練,每過100個可以生成圖片看看效果。
if np.mod(counter, 100) == 1

discriminator()

代碼自定義了一個conv2d,對tf.nn.conv2d稍加修改了。下面貼出tf.nn.conv2dtf.nn.conv2d
tf.contrib.layers.batch_norm的代碼見https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/layers.py
batchnormalization來自於http://arxiv.org/abs/1502.03167
加快訓練。
這裏寫圖片描述
激活函數lrelu見ops.py。四次卷積(其中三次卷積之前先批標準化)和激活之後。然後線性化,返回sigmoid函數處理後的結果。h3到h4的全連接相當於線性化,用一個矩陣將h3和h4連接起來,使h4是一個batch_size維的向量。

generator()

self.h0 = tf.reshape(self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])改變z_的形狀。-1代表的含義是不用我們自己指定這一維的大小,函數會自動計算,但列表中只能存在一個-1。(當然如果存在多個-1,就是一個存在多解的方程了)
deconv2d()deconv2d()
引用tf的反捲積函數tf.nn.conv2d_transpose或tf.nn.deconv2d。以tf.nn.conv2d_transpose爲例。
defconv2d_transpose(value, filter, output_shape, strides,padding=”SAME”, data_format=”NHWC”, name=None):

  • value: 是一個4維的tensor,格式爲[batch, height, width, in_channels] 或者 [batch, in_channels,height, width]。
  • filter: 是一個4維的tensor,格式爲[height, width, output_channels, in_channels],過濾器的in_ channels的維度要和這個匹配。
  • output_shape: 一維tensor,表示反捲積操作的輸出shapeA
  • strides: 針對每個輸入的tensor維度,滑動窗口的步長。
  • padding: “VALID”或者”SAME”,padding算法
  • data_format: “NHWC”或者”NCHW” ,對應value的數據格式。
  • name: 可選,返回的tensor名。

deconv= tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,strides=[1,d_h, d_w, 1])
第一個參數是輸入,即上一層的結果,
第二個參數是輸出輸出的特徵圖維數,是個4維的參數,
第三個參數卷積核的移動步長,[1, d_h, d_w, 1],其中第一個對應一次跳過batch中的多少圖片,第二個d_h對應一次跳過圖片中多少行,第三個d_w對應一次跳過圖片中多少列,第四個對應一次跳過圖像的多少個通道。這裏直接設置爲[1,2,2,1]。即每次反捲積後,圖像的滑動步長爲2,特徵圖會擴大縮小爲原來2*2=4倍。
這裏寫圖片描述

sampler()

和generator結構一樣,用的也是它的參數。存在的意義可能在於共享參數?
self.sampler = self.sampler(self.z, self.y)改爲self.sampler = self.generator(self.z, self.y)
報錯:
這裏寫圖片描述
所以sampler的存在還是有意義的。

load_mnist(), save(), load()
這三個加載保存等就不仔細講了。

download.py和ops.py好像也沒什麼好講的。
utils.py包含可視化等函數

參考:
Springenberg, Jost Tobias, Dosovitskiy, Alexey, Brox, Thomas, and Riedmiller, Martin. Striving for simplicity: The all convolutional net. arXiv preprint arXiv:1412.6806, 2014.
Mordvintsev, Alexander, Olah, Christopher, and Tyka, Mike. Inceptionism : Going deeper into neural networks.http://googleresearch.blogspot.com/2015/06/inceptionism-going-deeper-into-neural.html. Accessed: 2015-06-17.
Radford A, Metz L, Chintala S. UnsupervisedRepresentation Learning with Deep Convolutional Generative AdversarialNetworks[J]. Computer Science, 2015.
http://blog.csdn.net/nongfu_spring/article/details/54342861
http://blog.csdn.net/solomon1558/article/details/52573596

發佈了13 篇原創文章 · 獲贊 39 · 訪問量 7900
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章