Keras-CGAN_MNIST Code Walkthrough

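I recently read the CGAN paper (2014). It is short and to the point. CGANs can be used for image inpainting and multimodal recognition, which I find very interesting. When I have time I will also post my notes on the paper.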

Paper: Conditional Generative Adversarial Nets

First, the full code. Source: 【Keras-CGAN】MNIST / CIFAR-10

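This code combines the noise z with the label differently from the paper, and the same goes for the input image and the label. Here the label is turned into a learned embedding and multiplied element-wise with z (or with the flattened image). In the paper, z and y are first mapped through their own hidden layers and then combined in a joint hidden representation. I suspect this version does not fully reach the paper's results, but it still works well. The paper's mechanism is more involved, and for getting started this code already produces good results. The network is a simple multilayer perceptron (MLP) with no convolutional layers, so a DCGAN-style architecture would probably do better.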
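To make the difference concrete, here is a minimal NumPy sketch of the two ways of conditioning z on a label. Treat it as an illustration, not the paper's exact architecture. Path (a) mimics this code: a per-class vector is multiplied element-wise with z. Path (b) is a simplified, hypothetical version of the paper's idea: z and a one-hot y go through separate ReLU layers, and the results are concatenated into one joint representation. All weights (`emb_table`, `W_z`, `W_y`) are random stand-ins, and the hidden sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, z_dim, n_classes = 4, 100, 10

noise = rng.normal(0, 1, (batch, z_dim))       # z ~ N(0, 1), as in the training loop
labels = rng.integers(0, n_classes, (batch,))  # one class id per sample

# (a) Multiplicative conditioning (this code): look up a per-class 100-d vector
#     (random stand-in for the Embedding(10, 100) weights) and multiply it with z.
emb_table = rng.normal(0, 1, (n_classes, z_dim))
mult_input = noise * emb_table[labels]         # shape (batch, 100)

# (b) Simplified paper-style conditioning (hypothetical sizes): separate ReLU
#     branches for z and one-hot y, concatenated into one joint representation.
def relu(x):
    return np.maximum(x, 0.0)

one_hot = np.eye(n_classes)[labels]            # shape (batch, 10)
W_z = rng.normal(0, 0.1, (z_dim, 200))         # z branch -> 200 units
W_y = rng.normal(0, 0.1, (n_classes, 1000))    # y branch -> 1000 units
joint = np.concatenate([relu(noise @ W_z), relu(one_hot @ W_y)], axis=1)

print(mult_input.shape, joint.shape)  # (4, 100) (4, 1200)
```

In (a) the label can only rescale each noise component, so the input width stays at 100. In (b) the label has its own units, and the next layer can learn arbitrary interactions between them and z.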

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Embedding, LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

# build_generator
model = Sequential()

model.add(Dense(256, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))

model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))

model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))

model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))
model.add(Reshape((28, 28, 1)))

model.summary()

noise = Input(shape=(100,))  # 100-dim noise input; shape must be the tuple (100,), writing 100 without the comma fails
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, 100)(label))  # Embedding(number of classes, z dimension)

model_input = multiply([noise, label_embedding])  # element-wise product of noise and label embedding, used as the model input
print(model_input.shape)

img = model(model_input)  # output (28,28,1)

generator = Model([noise, label], img)

# build_discriminator
model = Sequential()

model.add(Flatten(input_shape=(28,28,1)))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))

model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))

model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))

model.add(Dense(1, activation='sigmoid'))
model.summary()

img = Input(shape=(28,28,1))  # input (28,28,1)
label = Input(shape=(1,), dtype='int32')

label_embedding = Flatten()(Embedding(10, np.prod((28,28,1)))(label))
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])  # combine the label embedding with the flattened image (real or G(z)) as the model input

validity = model(model_input)  # score the image-label pair: real or fake
discriminator = Model([img, label], validity)


#compile model
optimizer = Adam(0.0002, 0.5)

# discriminator
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])


# The combined model  (stacked generator and discriminator)
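noise = Input(shape=(100,))
label = Input(shape=(1,), dtype='int32')  # same integer label input as in G and D

# For the combined model we will only train the generator, so freeze D before compiling `combined`.
# In Keras 2.x the trainable flag is read at compile time, so D (compiled above) still learns in its own train_on_batch.
discriminator.trainable = False

img = generator([noise, label])
validity = discriminator([img, label])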

# Trains the generator to fool the discriminator
combined = Model([noise,label], validity)
combined.summary()
combined.compile(loss='binary_crossentropy',
                 optimizer=optimizer)

def sample_images(epoch):
    r, c = 2, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    sampled_labels = np.arange(0, 10).reshape(-1, 1)
    gen_imgs = generator.predict([noise, sampled_labels])
    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].set_title("Digit: %d" % sampled_labels[cnt, 0])  # index the scalar label, not a 1-element array
            axs[i,j].axis('off')
            cnt += 1
    import os
    os.makedirs("images", exist_ok=True)  # savefig fails if the output folder does not exist
    fig.savefig("images/mnist%d.png" % epoch)
    plt.close()

batch_size = 32
sample_interval = 200

# Load the dataset
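(X_train, y_train), (_, _) = mnist.load_data()  # X_train: (60000, 28, 28), y_train: (60000,)
# Rescale to [-1, 1] to match the generator's tanh output: /127.5 maps 0..255 to 0..2, then subtract 1
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)  # (60000, 28, 28, 1)
y_train = y_train.reshape(-1, 1)           # (60000, 1), matching the label Input of shape (1,)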
# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(50001):  # each "epoch" here is one training iteration on a single batch
    # ---------------------
    #  Train Discriminator
    # ---------------------

    # Select a random batch of images
    idx = np.random.randint(0, X_train.shape[0], batch_size)  # random indices in [0, 60000)
    imgs, labels = X_train[idx], y_train[idx]
    noise = np.random.normal(0, 1, (batch_size, 100))  # standard Gaussian noise

    # Generate a batch of new images
    gen_imgs = generator.predict([noise,labels])

    # Train the discriminator
    d_loss_real = discriminator.train_on_batch([imgs, labels], valid)  # real data -> target 1
    d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)  # generated data -> target 0
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---------------------
    #  Train Generator
    # ---------------------
    #noise = np.random.normal(0, 1, (batch_size, 100))
    sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
    # Train the generator (to have the discriminator label samples as valid)
    g_loss = combined.train_on_batch([noise, sampled_labels], valid)

    # Plot the progress
    if epoch % 200==0:
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

    # If at save interval => save generated image samples
    if epoch % sample_interval == 0:
        sample_images(epoch)
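The code scales pixels one way for training and the opposite way when plotting. This small NumPy sketch checks both directions on a few made-up pixel values (`raw` is hypothetical test data). Training maps 0..255 to [-1, 1], the range of the generator's `tanh` output. `sample_images` then maps [-1, 1] back to [0, 1] for `imshow`.

```python
import numpy as np

# Hypothetical uint8 pixels covering the MNIST range 0..255.
raw = np.array([[0, 64, 127, 255]], dtype=np.uint8)

# Training-time scaling, as in the code: /127.5 -> [0, 2], then -1 -> [-1, 1].
scaled = raw / 127.5 - 1.0

# Plot-time inverse, as in sample_images: [-1, 1] -> [0, 1].
display = 0.5 * scaled + 0.5

print(scaled.min(), scaled.max())          # -1.0 1.0
print(np.allclose(display, raw / 255.0))   # True
```

Because `0.5 * (x / 127.5 - 1) + 0.5 = x / 255`, the plotted images are simply the original pixels in [0, 1].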

This walkthrough focuses on the part where conditional information (the label) is added to the inputs of both G and D.

1. G: the noise z fed into G is combined with the label

`model` defines an MLP-based G network. Then:

noise = Input(shape=(100,))  # 100-dim noise input; shape must be the tuple (100,), writing 100 without the comma fails
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, 100)(label))  # Embedding(number of classes, z dimension)

model_input = multiply([noise, label_embedding])  # element-wise product of noise and label embedding, used as the model input
print(model_input.shape)

img = model(model_input)  # output (28,28,1)

generator = Model([noise, label], img)

The key is understanding the Embedding layer; see the official Keras documentation and related blog posts.
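One way to understand `Embedding(10, 100)` is as a trainable (10, 100) weight table, where looking up a label means picking a row. The NumPy sketch below shows that this lookup gives the same result as multiplying a one-hot label vector by the table. The table `W` is randomly initialised here purely for illustration. In Keras it would be learned during training.

```python
import numpy as np

rng = np.random.default_rng(42)
n_classes, z_dim = 10, 100

# Random stand-in for the (10, 100) weight table of Embedding(10, 100).
W = rng.normal(0, 0.05, (n_classes, z_dim))

labels = np.array([[3], [7], [3]])  # shape (batch, 1), integer class ids like the `label` Input

# Embedding lookup: index rows of the table -> shape (batch, 1, 100)
looked_up = W[labels]

# Equivalent view: one-hot encode, then matrix-multiply -> shape (batch, 100)
one_hot = np.eye(n_classes)[labels[:, 0]]
via_matmul = one_hot @ W

print(looked_up.shape)                              # (3, 1, 100)
print(np.allclose(looked_up[:, 0, :], via_matmul))  # True
```

Samples with the same label (rows 0 and 2 above) get exactly the same vector. That per-class vector is what carries the condition into G.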

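  • label_embedding: the Embedding layer maps each label from a "vocabulary" of size 10 (10 classes) to a 100-dimensional vector, the same dimension as noise.
  • Flatten flattens each sample to one dimension: the Embedding output of shape (batch, 1, 100) becomes (batch, 100).
  • The Multiply layer computes the element-wise product of a list of input tensors, which here combines the label with the noise z. All tensors in the list must have the same shape, and it returns one tensor of that same shape. This is why the previous step converted the label to the same dimension as noise.
  • img = model(model_input)  # output (28,28,1): generates an image.
  • Because of the label fusion above, the G model is defined as: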

generator = Model([noise, label], img)   # defines the final G network

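  • Keras has two types of models: the Sequential model and the functional-API Model.
  • Model(inputs, outputs): in generator = Model([noise, label], img), the inputs are noise and label, and G's output is still an image of size (28,28,1).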
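To follow the shapes through G's input pipeline, here is a NumPy sketch of Embedding, then Flatten, then multiply for a batch of 32. It only mirrors the tensor shapes. `emb_weights` is a random stand-in for the Embedding table, and NumPy operations replace the Keras layers.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, z_dim, n_classes = 32, 100, 10

noise = rng.normal(0, 1, (batch, z_dim))                    # like Input(shape=(100,))
label = rng.integers(0, n_classes, (batch, 1))              # like Input(shape=(1,), dtype='int32')
emb_weights = rng.uniform(-0.05, 0.05, (n_classes, z_dim))  # random stand-in for Embedding(10, 100)

embedded = emb_weights[label]                  # Embedding output: (32, 1, 100)
label_embedding = embedded.reshape(batch, -1)  # Flatten keeps the batch axis: (32, 100)
model_input = noise * label_embedding          # multiply([noise, label_embedding]): (32, 100)

# Without the Flatten step, NumPy would silently broadcast (32, 100) * (32, 1, 100)
# into a (32, 32, 100) array, pairing every noise vector with every label.
broadcast = noise * embedded

print(embedded.shape, label_embedding.shape, model_input.shape)
# (32, 1, 100) (32, 100) (32, 100)
print(broadcast.shape)  # (32, 32, 100)
```

The last line shows why the dimensions have to match before the element-wise product. In NumPy, a mismatch can broadcast silently instead of failing.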

2. D: D's input img (a real or generated image) is combined with the label

img = Input(shape=(28,28,1))  # input (28,28,1)
label = Input(shape=(1,), dtype='int32')

label_embedding = Flatten()(Embedding(10, np.prod((28,28,1)))(label))
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])  # combine the label embedding with the flattened image (real or G(z)) as the model input

validity = model(model_input)  # score the image-label pair: real or fake
discriminator = Model([img, label], validity)  # defines the final D network
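  • np.prod() computes the product of all elements, so np.prod((28,28,1)) = 784. The Embedding layer therefore maps each label from the size-10 "vocabulary" (10 classes) to a 784-dimensional vector, and Flatten turns it into one dimension per sample. Multiply then combines the label with the flattened input img.
  • validity = model(model_input): D's model scores the combined label and image, judging whether it is real or fake.
  • discriminator = Model([img, label], validity)  # defines the final D network
  • The rest (compiling, the combined model, and the training procedure including the loss function) follows the same design as DCGAN; the only change is that the label is added to the inputs.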
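  • At sampling time, sampled_labels = np.arange(0, 10).reshape(-1, 1) supplies the labels 0 to 9 as a (10, 1) array.
  • gen_imgs = generator.predict([noise, sampled_labels])
  • The result is one generated image for each label from 0 to 9.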
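The sketch below covers D's side in NumPy. First it builds D's input the same way as the code: a 784-d label embedding times the flattened image. Then it computes binary cross-entropy against the `valid`/`fake` targets used in training. The `binary_crossentropy` helper is written by hand for illustration and is not the Keras function. The D scores (0.9 and 0.1) are hypothetical numbers, not results from this model.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, n_classes = 4, 10
img_dim = int(np.prod((28, 28, 1)))  # 784, as in Embedding(10, np.prod((28,28,1)))

imgs = rng.uniform(-1, 1, (batch, 28, 28, 1))                 # images scaled to [-1, 1]
labels = rng.integers(0, n_classes, (batch, 1))
emb_weights = rng.uniform(-0.05, 0.05, (n_classes, img_dim))  # random stand-in for the Embedding table

flat_img = imgs.reshape(batch, -1)                        # Flatten()(img): (4, 784)
label_embedding = emb_weights[labels].reshape(batch, -1)  # Flatten()(Embedding(...)(label)): (4, 784)
d_input = flat_img * label_embedding                      # multiply: (4, 784)

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Hand-written mean binary cross-entropy (illustrative helper)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))

valid = np.ones((batch, 1))   # targets for real images
fake = np.zeros((batch, 1))   # targets for generated images

# Hypothetical D outputs: fairly confident on real (0.9) and on fake (0.1) images.
scores_real = np.full((batch, 1), 0.9)
scores_fake = np.full((batch, 1), 0.1)

d_loss = 0.5 * (binary_crossentropy(valid, scores_real) + binary_crossentropy(fake, scores_fake))
g_loss = binary_crossentropy(valid, scores_fake)  # G's fakes are scored against `valid`

print(img_dim, d_input.shape)              # 784 (4, 784)
print(round(d_loss, 4), round(g_loss, 4))  # 0.1054 2.3026
```

The same D scores give a small D loss but a large G loss. This is because the combined model labels generated images as `valid`, so G is rewarded only when D is fooled.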

Finally, here are the images generated after 50,000 iterations:

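Because G and D here are multilayer perceptrons (MLPs) without convolutional layers, the results are not great; switching to a DCGAN architecture would probably improve them considerably.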
