前言
鑑於網上許多GAN網絡的示例代碼結構不夠清晰,有些甚至無法正常迭代訓練,我決定寫一份結構完整、一目瞭然的GAN代碼,以幫助那些和我一樣剛開始接觸這類網絡的人。本篇中的GAN網絡僅由全連接層組成,以此複現最簡單的GAN網絡結構。
一、代碼結構
代碼由全局變量、生成器、判別器、GAN網絡、數據載入、訓練、範例圖片生成以及載入模型生成圖片這幾個部分組成。
二、代碼
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.keras import backend as K
from tensorflow.keras.utils import plot_model
from IPython.display import Image
import cv2
import PIL
import json, os
import sys
import labelme
import labelme.utils as utils
import glob
import itertools
class GAN():
    """Minimal fully-connected GAN trained on MNIST digits.

    The class wires together three Keras models:

    * ``generator_model``     - maps a latent noise vector to an image.
    * ``discriminator_model`` - maps an image to a real/fake probability.
    * ``model``               - generator followed by the (frozen)
      discriminator; training it updates only the generator.
    """

    def __init__(self,
                 img_shape=(28, 28, 1),
                 save_path=r'C:\Users\Administrator\Desktop\photo\GAN.h5',
                 img_path=r'C:\Users\Administrator\Desktop\photo',
                 batch_size=20,
                 latent_dim=100,
                 sample_interval=1,
                 epoch=100):
        """Store the configuration and build all three models.

        Every parameter defaults to the value that used to be hard-coded,
        so ``GAN()`` behaves as before.

        Args:
            img_shape: Shape of one image, ``(height, width, channels)``.
            save_path: File the combined model is saved to / loaded from.
            img_path: Directory that receives the sample image grids.
            batch_size: Number of images per training step.
            latent_dim: Length of the generator's input noise vector.
            sample_interval: A sample grid is written every this many epochs.
            epoch: Number of training epochs.
        """
        self.img_shape = img_shape
        self.save_path = save_path
        self.img_path = img_path
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.sample_interval = sample_interval
        self.epoch = epoch
        # Build the models; the combined model needs the other two first.
        self.generator_model = self.build_generator()
        self.discriminator_model = self.build_discriminator()
        self.model = self.bulid_model()

    def _sample_noise(self, n):
        """Return ``n`` standard-normal latent vectors, shape ``(n, latent_dim)``."""
        return np.random.normal(0, 1, (n, self.latent_dim))

    def build_generator(self):
        """Build the generator: noise vector -> image with values in [0, 1].

        Returns:
            A Keras model named ``'generator'``.
        """
        # A 1-tuple is used because not every Keras version accepts a bare int.
        noise_in = keras.Input(shape=(self.latent_dim,))
        x = noise_in
        # Three widening Dense -> LeakyReLU -> BatchNorm stages.
        for units in (256, 512, 1024):
            x = layers.Dense(units)(x)
            x = layers.LeakyReLU(alpha=0.2)(x)
            x = layers.BatchNormalization(momentum=0.8)(x)
        # Sigmoid keeps the output in [0, 1], matching load_data()'s scaling.
        x = layers.Dense(int(np.prod(self.img_shape)), activation='sigmoid')(x)
        output = layers.Reshape(self.img_shape)(x)
        model = keras.Model(inputs=noise_in, outputs=output, name='generator')
        model.summary()
        return model

    def build_discriminator(self):
        """Build the discriminator: image -> probability of being real.

        Returns:
            A Keras model named ``'discriminator'``.
        """
        img_in = keras.Input(shape=self.img_shape)
        x = layers.Flatten()(img_in)
        x = layers.Dense(512)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.Dense(256)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        output = layers.Dense(1, activation='sigmoid')(x)
        model = keras.Model(inputs=img_in, outputs=output, name='discriminator')
        model.summary()
        return model

    def bulid_model(self):
        """Compile the discriminator and build the combined GAN model.

        The method name keeps its original spelling so existing callers
        continue to work.

        Returns:
            The compiled combined model (noise -> generator -> discriminator).
        """
        # NOTE(review): the original passed (0.0001, 0.000001) positionally,
        # which sets beta_1 to 1e-6. The values are kept, but named; confirm
        # that such a small beta_1 is really intended.
        self.discriminator_model.compile(
            loss='binary_crossentropy',
            optimizer=keras.optimizers.Adam(learning_rate=0.0001,
                                            beta_1=0.000001),
            metrics=['accuracy'])
        # Freeze the discriminator inside the combined model so that training
        # the combined model only updates the generator.
        self.discriminator_model.trainable = False
        inputs = keras.Input(shape=(self.latent_dim,))
        img = self.generator_model(inputs)
        outputs = self.discriminator_model(img)
        model = keras.Model(inputs=inputs, outputs=outputs)
        model.summary()
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.0001,
                                            beta_1=0.000001),
            loss='binary_crossentropy',
        )
        return model

    def load_data(self):
        """Load the MNIST training images, scaled to [0, 1].

        Returns:
            Array of shape ``(num_images, 28, 28, 1)``.
        """
        # Only the training images are used; labels and test split are dropped.
        (train_images, _), _ = keras.datasets.mnist.load_data()
        train_images = train_images / 255
        train_images = np.expand_dims(train_images, axis=-1)  # add channel axis
        print('img_number:', train_images.shape)
        return train_images

    def train(self):
        """Train the GAN, writing sample grids and saving the combined model."""
        train_images = self.load_data()
        # Target labels: 1 for real images, 0 for generated ones.
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))
        # Number of full batches per epoch; a trailing partial batch is skipped.
        step = train_images.shape[0] // self.batch_size
        print('step:', step)
        # Make sure the directory for the saved model exists.
        os.makedirs(os.path.dirname(self.save_path) or '.', exist_ok=True)
        for epoch in range(self.epoch):
            # Reshuffle the training set once per epoch.
            train_images = tf.random.shuffle(train_images).numpy()
            if epoch % self.sample_interval == 0:
                self.generate_sample_images(epoch)
            for i in range(step):
                start = i * self.batch_size
                imgs = train_images[start:start + self.batch_size]
                noise = self._sample_noise(self.batch_size)
                gan_imgs = self.generator_model.predict(noise, verbose=0)
                # ---- train the discriminator: real -> 1, generated -> 0
                discriminator_loss_real = \
                    self.discriminator_model.train_on_batch(imgs, valid)
                discriminator_loss_fake = \
                    self.discriminator_model.train_on_batch(gan_imgs, fake)
                discriminator_loss = 0.5 * np.add(discriminator_loss_real,
                                                  discriminator_loss_fake)
                # ---- train the generator through the combined model: it is
                # rewarded when the discriminator labels its output as real.
                noise = self._sample_noise(self.batch_size)
                generator_loss = self.model.train_on_batch(noise, valid)
                if i % 10 == 0:  # log every ten steps
                    print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                        epoch, i, discriminator_loss[0],
                        100 * discriminator_loss[1], generator_loss))
            # Overwrite the saved model after every epoch so an interrupted
            # run still leaves the most recent weights on disk.
            self.model.save(self.save_path)

    def generate_sample_images(self, epoch):
        """Save a 5x5 grid of generated images as ``<img_path>/<epoch>.png``.

        Args:
            epoch: Epoch number, used as the file name.
        """
        row, col = 5, 5
        noise = self._sample_noise(row * col)
        gan_imgs = self.generator_model.predict(noise, verbose=0)
        fig, axs = plt.subplots(row, col)
        # axs.flat walks the grid row by row, matching the image order.
        for ax, img in zip(axs.flat, gan_imgs):
            ax.imshow(img[:, :, 0], cmap='gray')
            ax.axis('off')
        os.makedirs(self.img_path, exist_ok=True)
        fig.savefig(os.path.join(self.img_path, "%d.png" % epoch))
        plt.close(fig)  # close this figure explicitly so figures do not pile up

    def pred(self):
        """Load the saved combined model and display one generated image."""
        model = keras.models.load_model(self.save_path)
        model.summary()
        noise = self._sample_noise(1)
        # layers[0] is the input layer and layers[1] the nested generator
        # (see bulid_model); rebuild a standalone model from it.
        generator = keras.Model(inputs=model.layers[1].input,
                                outputs=model.layers[1].output)
        generator.summary()
        img = np.squeeze(generator.predict(noise, verbose=0))
        plt.imshow(img, cmap='gray')  # same colormap as the sample grids
        plt.show()
        print(img.shape)
if __name__ == '__main__':
    # Use a lowercase instance name so the GAN class itself is not rebound
    # to an instance (the original `GAN = GAN()` shadowed the class).
    gan = GAN()
    gan.train()
三、實驗現象
隨著訓練輪數增加,可以看到網絡生成的數字越來越逼真。