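I recently read the CGAN paper. It is from 2014, short and to the point. CGAN can be used for image inpainting and multimodal recognition, which I find quite interesting. When I have time I will also post my notes on the paper.
Paper: Conditional Generative Adversarial Nets
First, the complete code. Source: 【Keras-CGAN】MNIST / CIFAR-10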
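The way this code combines the noise z with the label, and the input image with the label, differs from the paper. The code multiplies each input element-wise by a learned label embedding, while the paper feeds the label y as an extra input that is joined with z (and, in the discriminator, with x) in a joint hidden representation. I feel the code does not fully reproduce the paper's effect, but it works well. The paper's mechanism is more involved, and for getting started this version already produces fairly good results. The network here is a simple multilayer perceptron with no convolutional layers; a DCGAN-style architecture would probably do better.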
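To make the two mechanisms concrete, here is a small NumPy sketch (not Keras, and not the paper's exact network) of both ways of injecting the label into the generator input: the element-wise product with a label embedding used in this code, and concatenating a one-hot label with z, which is the simplest form of "y as an extra input" in the spirit of the paper. The embedding table is random purely for illustration; in the real model it is a trainable weight.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, z_dim, num_classes = 4, 100, 10

z = rng.normal(0.0, 1.0, size=(batch_size, z_dim))  # noise, shape (4, 100)
y = np.array([3, 1, 4, 1])                          # integer class labels

# (a) This post's code: look up a z_dim-sized vector per class and multiply element-wise.
#     The random table stands in for the trainable Embedding(10, 100) weights.
embedding_table = rng.normal(size=(num_classes, z_dim))
g_input_multiply = z * embedding_table[y]           # width stays 100

# (b) Concatenation: append the one-hot condition y to z, so the first layer sees both.
y_onehot = np.eye(num_classes)[y]                   # shape (4, 10)
g_input_concat = np.concatenate([z, y_onehot], axis=1)  # width becomes 110

print(g_input_multiply.shape, g_input_concat.shape)
```

Multiplication keeps the input width at 100, so the generator's first Dense(256, input_dim=100) layer needs no change; concatenation leaves z untouched and lets the first layer learn how to mix noise and label, at the cost of a wider input.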
```python
import matplotlib.pyplot as plt
import numpy as np
# Keras 2.x API. LeakyReLU is imported from keras.layers directly, because the old
# keras.layers.advanced_activations module is not available in newer Keras versions.
from keras.datasets import mnist
from keras.layers import (BatchNormalization, Dense, Dropout, Embedding, Flatten,
                          Input, LeakyReLU, Reshape, multiply)
from keras.models import Model, Sequential
from keras.optimizers import Adam
```
```python
# build_generator: an MLP mapping a 100-d (label-conditioned) noise vector to a 28x28x1 image
model = Sequential()
model.add(Dense(256, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))  # tanh outputs lie in [-1, 1]
model.add(Reshape((28, 28, 1)))
model.summary()

noise = Input(shape=(100,))  # shape must be the tuple (100,); a bare 100 is not a tuple and fails
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, 100)(label))  # (num classes, z dimension)
model_input = multiply([noise, label_embedding])  # combine label embedding and noise element-wise as the model input
print(model_input.shape)
img = model(model_input)  # output (28, 28, 1)
generator = Model([noise, label], img)
```
```python
# build_discriminator: an MLP scoring a flattened, label-conditioned image
model = Sequential()
# The model input is the 784-d product of the flattened image and the label embedding
# (built below), so the first layer takes a flat vector, not a (28, 28, 1) image.
model.add(Dense(512, input_dim=np.prod((28, 28, 1))))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))
model.add(Dense(1, activation='sigmoid'))  # probability that the (image, label) pair is real
model.summary()

img = Input(shape=(28, 28, 1))  # input (28, 28, 1)
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, np.prod((28, 28, 1)))(label))  # one 784-d vector per class
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])  # combine label and image (real or G(z)) as the model input
validity = model(model_input)
discriminator = Model([img, label], validity)
```
```python
# compile model
optimizer = Adam(0.0002, 0.5)  # learning rate 0.0002, beta_1 0.5
# discriminator
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

# The combined model (stacked generator and discriminator)
noise = Input(shape=(100,))
label = Input(shape=(1,), dtype='int32')  # same integer label input as in G and D
img = generator([noise, label])
# For the combined model we will only train the generator: freeze D before building
# and compiling the combined model. In Keras 2.x, D was already compiled above, so
# discriminator.train_on_batch still updates D's weights.
discriminator.trainable = False
validity = discriminator([img, label])
# Trains the generator to fool the discriminator
combined = Model([noise, label], validity)
combined.summary()
combined.compile(loss='binary_crossentropy',
                 optimizer=optimizer)
```
```python
import os


def sample_images(epoch):
    r, c = 2, 5  # a 2 x 5 grid: one image for each digit 0-9
    noise = np.random.normal(0, 1, (r * c, 100))
    sampled_labels = np.arange(0, 10).reshape(-1, 1)
    gen_imgs = generator.predict([noise, sampled_labels])

    # Rescale images from [-1, 1] (tanh range) to [0, 1] for display
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].set_title("Digit: %d" % sampled_labels[cnt, 0])
            axs[i, j].axis('off')
            cnt += 1
    os.makedirs("images", exist_ok=True)  # savefig fails if the output folder is missing
    fig.savefig("images/mnist%d.png" % epoch)
    plt.close()
```
```python
batch_size = 32
sample_interval = 200

# Load the dataset
(X_train, y_train), (_, _) = mnist.load_data()  # X_train: (60000, 28, 28)
# Rescale -1 to 1: tanh outputs lie in [-1, 1], so map pixels from [0, 255] to [-1, 1]
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)  # (60000, 28, 28, 1)
y_train = y_train.reshape(-1, 1)  # (60000, 1), matching Input(shape=(1,))

# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

# Each "epoch" here is one training step on a single batch, not a full pass over the data.
for epoch in range(50001):
    # ---------------------
    #  Train Discriminator
    # ---------------------
    # Select a random batch of images and their labels
    idx = np.random.randint(0, X_train.shape[0], batch_size)  # random indices in [0, 60000)
    imgs, labels = X_train[idx], y_train[idx]
    noise = np.random.normal(0, 1, (batch_size, 100))  # standard Gaussian noise

    # Generate a batch of new images conditioned on the same labels
    gen_imgs = generator.predict([noise, labels])

    # Train the discriminator
    d_loss_real = discriminator.train_on_batch([imgs, labels], valid)  # real data -> target 1
    d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)  # generated data -> target 0
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # average of [loss, accuracy]

    # ---------------------
    #  Train Generator
    # ---------------------
    # Reuse the noise drawn above, paired with freshly sampled labels
    sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
    # Train the generator (to have the discriminator label samples as valid)
    g_loss = combined.train_on_batch([noise, sampled_labels], valid)

    # Print the progress and, at each save interval, save generated image samples
    if epoch % sample_interval == 0:
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
        sample_images(epoch)
```
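The code rescales pixels twice: once when loading MNIST (from [0, 255] to [-1, 1], the range of the generator's tanh output) and once in sample_images (from [-1, 1] back to [0, 1] for imshow). A quick NumPy check on a made-up 2x2 image shows that the two steps together amount to a plain division by 255.

```python
import numpy as np

# Hypothetical 2x2 "image" with raw MNIST-style pixel values in [0, 255].
raw = np.array([[0, 64], [128, 255]], dtype=np.uint8)

# Training-time scaling, as in the code: [0, 255] -> [-1, 1], matching tanh's range.
scaled = raw / 127.5 - 1.0

# Display-time scaling, as in sample_images: [-1, 1] -> [0, 1] for imshow.
display = 0.5 * scaled + 0.5

print(scaled)
print(display)
```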
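The walkthrough below focuses on the part that makes the GAN conditional: adding the condition (the label) to the inputs of both G and D.
1. G: the generator's input noise z is combined with the label
model defines an MLP-based generator network; the label is then merged in as follows: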
```python
noise = Input(shape=(100,))  # shape must be the tuple (100,); a bare 100 is not a tuple and fails
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, 100)(label))  # (num classes, z dimension)
model_input = multiply([noise, label_embedding])  # combine label embedding and noise element-wise as the model input
print(model_input.shape)
img = model(model_input)  # output (28, 28, 1)
generator = Model([noise, label], img)
```
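The key is understanding the Embedding layer; see the official Keras documentation and related blog posts.
- label_embedding maps the label, whose "vocabulary" has size 10 (10 classes in total), to a 100-dimensional vector, the same dimension as the noise.
- The Flatten layer flattens its input to one dimension per sample: the Embedding output of shape (batch, 1, 100) becomes (batch, 100).
- The Multiply layer (multiply) computes the element-wise product of a list of input tensors; here it combines the label with the noise z. All tensors in the list must have the same shape, and the result has that same shape, which is why the previous step maps the label to the same dimension as the noise.
- img = model(model_input) passes the combined input through the MLP and outputs one (28, 28, 1) image per sample.
- Because of the label fusion above, the generator is defined as:
  generator = Model([noise, label], img)  # the final G network
- Keras has two kinds of models: the Sequential model and the functional-API Model.
- Model(inputs, outputs): in generator = Model([noise, label], img) the inputs are the noise and the label, and G's output is still an image of size (28, 28, 1).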
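What Embedding(10, 100), Flatten and multiply do together can be reproduced in plain NumPy: the embedding is a (10, 100) lookup table indexed by the integer label. The sketch below uses a random table as a stand-in for the learned weights and only tracks shapes and values; it illustrates the computation and is not the Keras implementation.

```python
import numpy as np

rng = np.random.default_rng(42)
num_classes, z_dim, batch_size = 10, 100, 3

# Stand-in for the Embedding(10, 100) weight matrix (learned during training in Keras).
embedding_weights = rng.normal(size=(num_classes, z_dim))

label = np.array([[7], [0], [7]])             # like Input(shape=(1,)): shape (3, 1)
noise = rng.normal(size=(batch_size, z_dim))  # like Input(shape=(100,)): shape (3, 100)

embedded = embedding_weights[label]                 # Embedding output: shape (3, 1, 100)
label_embedding = embedded.reshape(batch_size, -1)  # Flatten: shape (3, 100)
model_input = noise * label_embedding               # multiply: element-wise, shape (3, 100)

print(embedded.shape, label_embedding.shape, model_input.shape)
```

Samples with the same label get the same embedding vector, so within one class the label acts as a fixed per-dimension scaling of z.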
2. D: the discriminator's input img (a real or generated image) is combined with the label
```python
img = Input(shape=(28, 28, 1))  # input (28, 28, 1)
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, np.prod((28, 28, 1)))(label))
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])  # combine label and image (real or G(z)) as the model input
validity = model(model_input)
discriminator = Model([img, label], validity)  # the final D network
```
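On the D side the same lookup-and-multiply idea is applied to the flattened image: np.prod((28, 28, 1)) is 784, so each class gets a 784-dimensional vector that is multiplied pixel by pixel with the flattened image. The NumPy sketch below again uses a random table as a stand-in for the trainable Embedding(10, 784) weights.

```python
import numpy as np

img_shape = (28, 28, 1)
flat_dim = int(np.prod(img_shape))  # 28 * 28 * 1

rng = np.random.default_rng(1)
batch_size = 2
# Stand-in for the Embedding(10, 784) weights: one 784-d vector per class.
embedding_weights = rng.normal(size=(10, flat_dim))

img = rng.uniform(-1.0, 1.0, size=(batch_size,) + img_shape)  # images scaled to [-1, 1]
label = np.array([[5], [2]])                                   # like Input(shape=(1,))

flat_img = img.reshape(batch_size, -1)                              # Flatten: (2, 784)
label_embedding = embedding_weights[label].reshape(batch_size, -1)  # Embedding + Flatten: (2, 784)
model_input = flat_img * label_embedding                            # multiply: (2, 784)

print(flat_dim, model_input.shape)
```

The result is a flat 784-dimensional vector per sample, which is why the discriminator MLP's first Dense layer should expect a 784-dimensional input rather than a (28, 28, 1) image.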
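- np.prod() computes the product of all elements, so np.prod((28, 28, 1)) = 784: label_embedding maps each of the 10 labels to a 28*28*1 = 784-dimensional vector, and Flatten makes it one-dimensional. multiply again combines the label with the (flattened) input image.
- validity = model(model_input) uses the discriminator network model to score the combined image and label, i.e. to judge whether the pair is real or fake.
- discriminator = Model([img, label], validity) defines the final D network.
- Everything else (compiling the models, the combined model, and the training procedure including the loss functions) is the same as in DCGAN; the only difference is that the label is passed as an extra input wherever G or D is called, for example when sampling:
- sampled_labels = np.arange(0, 10).reshape(-1, 1)
- gen_imgs = generator.predict([noise, sampled_labels])
- The result is one generated image for each label 0-9.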
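The loss design shared with DCGAN is binary cross-entropy against the targets valid (all ones) and fake (all zeros). Because the discriminator is compiled with metrics=['accuracy'], each train_on_batch call returns [loss, accuracy], and 0.5 * np.add(...) averages both entries. The NumPy sketch below computes only the loss part by hand for made-up discriminator outputs; the clipping constant 1e-7 is assumed to mirror Keras's default epsilon.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip to avoid log(0); eps=1e-7 is assumed to match Keras's default epsilon.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))


batch_size = 4
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

# Hypothetical D outputs for real (image, label) pairs and for generated pairs.
d_real = np.array([[0.9], [0.8], [0.7], [0.95]])
d_fake = np.array([[0.2], [0.1], [0.4], [0.05]])

d_loss_real = binary_crossentropy(valid, d_real)
d_loss_fake = binary_crossentropy(fake, d_fake)
d_loss = 0.5 * (d_loss_real + d_loss_fake)  # same averaging as the training loop

# The generator step asks D to output "valid" (1) for generated samples.
g_loss = binary_crossentropy(valid, d_fake)

print(d_loss_real, d_loss_fake, d_loss, g_loss)
```

When D confidently rejects the fakes, d_loss_fake is small while g_loss is large, which is the signal that pushes G toward samples D accepts for the given label.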
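Finally, here are the images generated after 50,000 iterations.
Because G and D in this code are multilayer perceptrons (MLP) with no convolutional layers, the results are not great; switching to a DCGAN architecture would probably improve them considerably.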
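As a rough illustration of how a convolutional, DCGAN-style generator can reach 28x28 without a 784-unit Dense output layer: one common layout projects z to a small feature map and upsamples it twice, 7 to 14 to 28. The NumPy sketch below shows only the nearest-neighbour upsampling step (what Keras's UpSampling2D does with its default size=(2, 2)); the 7x7x128 starting shape is an assumption, and the convolutions between the upsampling steps are omitted.

```python
import numpy as np


def upsample_nearest(x, factor=2):
    # x has shape (batch, height, width, channels); repeat every row and column.
    return x.repeat(factor, axis=1).repeat(factor, axis=2)


batch_size = 1
# Assumed result of projecting z with a Dense layer and reshaping to (7, 7, 128).
feature_map = np.random.default_rng(0).normal(size=(batch_size, 7, 7, 128))

x = upsample_nearest(feature_map)  # (1, 14, 14, 128)
x = upsample_nearest(x)            # (1, 28, 28, 128)
# A final Conv2D with 1 filter and tanh (omitted here) would map this to (1, 28, 28, 1).
print(x.shape)
```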