前言
目前發現本人的網絡可以在傳統的MNIST手寫數據集上有良好的表現,但是將其應用於CIFAR10數據集的時候,出現了非常嚴重的圖像模糊現象。在實驗了多種傳統GAN的結構後,我的結論是:傳統的GAN對於圖片細節這類高頻信息的生成能力非常欠缺。現在我總結一下前期的工作,鑑於網上的代碼比較少,我自己寫了一個。
實驗的結構
使用了類似於語義分割的結構進行實驗,總之效果非常不好,雖然可以看出個大概的形狀但是對於細節方面過於模糊。下面的結構僅供作爲訓練生成器的反面教材。如果想要生成MNIST的手寫數據集的話,或者進行判別器的訓練的話,拿去用用倒是不錯。
代碼
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.keras import backend as K
from tensorflow.keras.utils import plot_model
from IPython.display import Image
import cv2
import PIL
import json, os
import sys
import labelme
import labelme.utils as utils
import glob
import itertools
class GAN():
    """Conditional GAN for CIFAR-10 (32x32x3 images).

    Generator: a semantic-segmentation-style upsampling network whose input
    noise is masked by the class label, so the label conditions the output.
    Discriminator: a CNN with a (num_classes + 1)-way softmax head, where the
    extra class index (num_classes) means "fake image".
    """
    def __init__(self,  # define the global configuration
                 ):
        self.img_shape = (32, 32, 3)  # (height, width, channels) of real/generated images
        self.save_path = r'C:\Users\Administrator\Desktop\photo\CGAN.h5'  # where the combined model is saved
        self.img_path = r'C:\Users\Administrator\Desktop\photo'  # folder where sample image grids are written
        self.batch_size = 100
        self.test_size = 200  # mini-batch size used when evaluating the discriminator in pred(mode=1)
        self.sample_interval = 1  # save sample images every `sample_interval` epochs
        self.epoch = 200
        self.num_classes = 10  # CIFAR-10 labels 0..9; index 10 is reserved for "fake"
        self.train_mode=0  # 0 = train from scratch, 1 = resume training (resuming with this code is unstable and not recommended)
        # Build the GAN sub-models (or reload them when resuming).
        if self.train_mode==0:
            self.generator_model = self.build_generator()
            self.discriminator_model = self.build_discriminator()
            self.model = self.bulid_model()
        else:
            # Reload the stacked model; in the saved combined model,
            # layers[1] is the generator and layers[2] the discriminator.
            self.model = keras.models.load_model(self.save_path)
            self.generator_model = keras.Model(inputs=self.model.layers[1].input, outputs=self.model.layers[1].output)
            self.discriminator_model = keras.Model(inputs=self.model.layers[2].input, outputs=self.model.layers[2].output)
    def build_generator(self):  # generator
        """Build the generator: (H/16, W/16, num_classes) noise -> image.

        The decoder mixes progressive UpSampling2D + Conv2D stages with large
        Conv2DTranspose skip paths taken directly from the input, similar to a
        segmentation decoder with multi-scale skips.
        """
        # Input noise tensor; only the channel matching the requested label
        # carries noise (see noise_generate).
        input = keras.Input(shape=(int(self.img_shape[0]/16),int(self.img_shape[1]/16),self.num_classes))
        # 2x2 -> 4x4 stage
        c00 = layers.UpSampling2D((2, 2))(input)
        x = layers.Conv2D(256,(3,3),padding='same',activation='relu')(c00)
        x = layers.Conv2D(256, (3, 3), padding='same',activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        # 4x4 -> 8x8 stage, concatenated with a stride-4 transposed-conv skip from the input
        c10 = layers.UpSampling2D((2, 2))(x)
        c11 = layers.Conv2DTranspose(256, (8,8), strides=4, padding='same', activation='relu')(input)
        x = layers.concatenate([c10,c11],axis=-1)
        x = layers.Conv2D(128, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(128, (3, 3), padding='same',activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        # 8x8 -> 16x16 stage, with a stride-8 skip
        c20 = layers.UpSampling2D((2, 2))(x)
        c21 = layers.Conv2DTranspose(128, (16, 16), strides=8, padding='same', activation='relu')(input)
        x = layers.concatenate([c20,c21],axis=-1)
        x = layers.Conv2D(64, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(64, (3, 3), padding='same',activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        # 16x16 -> 32x32 stage, with a stride-16 skip
        c30 = layers.UpSampling2D((2, 2))(x)
        c31 = layers.Conv2DTranspose(64, (32, 32), strides=16, padding='same', activation='relu')(input)
        x = layers.concatenate([c30,c31],axis=-1)
        x = layers.Conv2D(32, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(32, (3, 3), padding='same',activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        # Sigmoid output keeps pixels in [0, 1], matching the /255 scaling in load_data.
        output = layers.Conv2D(self.img_shape[2], (1, 1), padding='same', activation='sigmoid')(x)
        model = keras.Model(inputs=input, outputs=output, name='generator')
        model.summary()
        return model
    def build_discriminator(self):  # discriminator
        """Build the discriminator: image -> (num_classes + 1)-way softmax.

        Classes 0..num_classes-1 are the real CIFAR-10 labels; class
        num_classes means "fake".
        """
        input = keras.Input(shape=self.img_shape)
        x = layers.Conv2D(32, (3, 3), padding='same',activation='relu')(input)
        x = layers.Conv2D(32, (3, 3), padding='same',activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)
        x = layers.Conv2D(64, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(64, (3, 3), padding='same',activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)
        x = layers.Conv2D(128, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(128, (3, 3), padding='same',activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)
        x = layers.Conv2D(256, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(256, (3, 3), padding='same',activation='relu')(x)
        x = layers.Flatten()(x)
        output = layers.Dense(self.num_classes+1, activation='softmax')(x)
        model = keras.Model(inputs=input, outputs=output, name='discriminator')
        model.summary()
        return model
    # NOTE(review): "bulid" is a typo, kept because it is the method's public name.
    def bulid_model(self):  # build the stacked GAN model
        """Stack generator + discriminator: noise -> fake image -> class scores.

        Used to train the generator through the (frozen) discriminator.
        """
        inputs = keras.Input(shape=(int(self.img_shape[0]/16),int(self.img_shape[1]/16),self.num_classes))
        img = self.generator_model(inputs)
        outputs = self.discriminator_model(img)
        model = keras.Model(inputs=inputs, outputs=outputs)
        return model
    def load_data(self):
        """Load CIFAR-10, scale pixels to [0, 1], and build a shuffle index."""
        (train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
        train_images,test_images = (train_images /255),(test_images/255)
        if (self.img_shape[2]==1):
            # Single-channel configuration: add the trailing channel axis.
            train_images = np.expand_dims(train_images, axis=-1)
            test_images = np.expand_dims(test_images, axis=-1)
        train_idx = np.arange(0,train_images.shape[0],1)  # per-sample indices, reshuffled each epoch in train()
        print('img train number:', train_images.shape)
        print('img test number: ', test_images.shape)
        return train_images,train_labels,train_idx,test_images,test_labels
    def noise_generate(self,batch_size,labels):  # noise generator
        """Create label-conditioned noise of shape (batch, H/16, W/16, num_classes)."""
        noise = np.random.normal(0, 1, (batch_size, int(self.img_shape[0]/16), int(self.img_shape[1]/16), self.num_classes))  # standard Gaussian noise
        n_labels = keras.utils.to_categorical(labels, num_classes=self.num_classes)
        n_labels = n_labels.reshape(batch_size, 1, 1, self.num_classes)
        noise = noise * n_labels  # after masking, only the channel of the requested label carries noise
        return noise
    def compile(self):  # model compilation
        """Compile the discriminator, then freeze it inside the stacked model."""
        self.discriminator_model.compile(loss='categorical_crossentropy',
                                         optimizer=keras.optimizers.Adam(0.0001, 0.00001),
                                         metrics=['categorical_accuracy'])
        # Freeze the discriminator so only the generator updates when
        # training through the stacked model (standard GAN trick; the
        # already-compiled discriminator still trains on its own batches).
        self.discriminator_model = False if False else self.discriminator_model  # no-op guard removed below
        self.discriminator_model.trainable = False  # make the discriminator non-trainable inside the stacked model
        if self.train_mode ==1:
            # When resuming, rebuild the stacked model from the reloaded sub-models.
            self.model=self.bulid_model()
            print('continue train')
        self.model.summary()
        self.model.compile(optimizer=keras.optimizers.Adam(0.0001, 0.00001), loss='categorical_crossentropy', )
    def train(self):
        """Alternate discriminator / generator updates for `self.epoch` epochs."""
        self.compile()
        train_images,train_labels,train_idx,test_images,test_labels = self.load_data()  # load the dataset
        fake = np.ones((self.batch_size))*(self.num_classes)  # "fake" target: class index num_classes (i.e. 10)
        fake = keras.utils.to_categorical(fake,num_classes=self.num_classes+1)
        step = int(train_images.shape[0] / self.batch_size)  # batches per epoch
        print('step:', step)
        for epoch in range(self.epoch):
            train_idx = (tf.random.shuffle(train_idx)).numpy()  # reshuffle once per epoch
            print('val_acc', self.pred(mode=1, test_images=test_images, test_labels=test_labels))
            if epoch % self.sample_interval == 0:  # periodically save sample images
                self.generate_sample_images(epoch)
            for i in range(step):
                idx = train_idx[i * self.batch_size:i * self.batch_size + self.batch_size]  # mini-batch indices
                imgs = train_images[idx]  # real images for this mini-batch
                labels = train_labels[idx]
                # --------------------------------------------------------------- build label-masked Gaussian noise
                noise = self.noise_generate(self.batch_size,labels)
                gan_imgs = self.generator_model.predict(noise)  # generate fake images from the noise
                # ---------------------------------------------------------------
                labels=keras.utils.to_categorical(labels,num_classes=self.num_classes+1)  # one-hot targets for the real images
                total_imgs=tf.concat((gan_imgs,imgs),axis=0)
                total_labels=tf.concat((fake,labels),axis=0)
                # ---------------------------------------------- train the discriminator on fake + real
                discriminator_loss = self.discriminator_model.train_on_batch(total_imgs, total_labels)
                # ----------------------------------------------- train the generator through the frozen discriminator
                generator_loss = self.model.train_on_batch(noise, labels)
                print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                    epoch, i, discriminator_loss[0], 100 * discriminator_loss[1], generator_loss))
            # print('val_acc', self.pred(mode=1, test_images=test_images, test_labels=test_labels))
            self.model.save(self.save_path)  # save the combined model after every epoch
            print('save model')
    def generate_sample_images(self, epoch):
        """Save a 2x2 grid of generated images for random labels to img_path."""
        row, col = 2, 2  # grid dimensions
        labels = np.random.randint(0,self.num_classes,(row * col))
        noise = self.noise_generate(row * col,labels)  # label-conditioned noise
        gan_imgs = ((self.generator_model.predict(noise)))
        fig, axs = plt.subplots(row, col)  # drawing board
        idx = 0
        for i in range(row):
            for j in range(col):
                axs[i, j].imshow(gan_imgs[idx, :, :])  # cmap='gray')
                axs[i, j].axis('off')
                idx += 1
        fig.savefig(self.img_path + "/%d.png" % epoch)
        plt.close()  # close the figure to free memory
    def pred(self,mode,test_images=None,test_labels=None):  # how to use the trained networks
        """mode 0: reload the saved model and display one generated image per label.
        mode 1: evaluate discriminator accuracy on the test set; the "fake"
        score is zeroed so only the 10 real classes compete in the argmax.
        """
        if (mode==0):  # test the generator
            model = keras.models.load_model(self.save_path)
            print('loading model')
            generator = keras.Model(inputs=model.layers[1].input, outputs=model.layers[1].output)
            discriminator = keras.Model(inputs=model.layers[2].input, outputs=model.layers[2].output)
            generator.summary()
            discriminator.summary()
            for i in range(10):  # one sample per class label
                label = i
                noise = self.noise_generate(1, label)
                img = np.squeeze(generator.predict([noise]))
                plt.imshow(img)
                plt.show()
        elif(mode==1):  # validation: measure discriminator accuracy with the fake class suppressed
            print('testing')
            step=int(test_images.shape[0]/self.test_size)
            val_acc=0
            for i in range(step):
                pred=self.discriminator_model(test_images[i*self.test_size:(i+1)*self.test_size])
                pred=pred.numpy()
                pred[:,(self.num_classes)]=0  # suppress the "fake" score so argmax picks a real label
                pred=tf.argmax(pred,axis=-1)
                pred=keras.utils.to_categorical(pred,num_classes=self.num_classes+1)
                labels=keras.utils.to_categorical(test_labels[i*self.test_size:(i+1)*self.test_size],num_classes=self.num_classes+1)
                # NOTE(review): 1 - mean |one-hot difference| is not plain accuracy —
                # a misclassification changes 2 of (num_classes+1) entries, so this
                # overstates accuracy relative to the usual definition; confirm intent.
                acc=1-tf.reduce_mean(tf.abs(pred-labels))
                val_acc+=acc
            val_acc=val_acc/step
            return val_acc.numpy()
        else:
            pass
if __name__ == '__main__':
    # Use a distinct instance name: the original `GAN = GAN()` shadowed the
    # class with its own instance, making the class unreachable afterwards.
    gan = GAN()
    gan.train()           # train from scratch (or resume, per train_mode)
    gan.pred(mode=0)      # mode 0: reload the saved model and show generated samples
效果
雖然在生成器上失敗了,但是在判別器的效果上做得很好,對於CIFAR10數據集驗證集的分類精度達到了96%。雖然沒有超過最高的97%,但是我認爲已經超過很多常規的算法了。生成的圖片實在是不好意思拿出來。