TensorFlow Keras: Lessons Learned from CGAN Experiments on the CIFAR-10 Dataset

Preface

I have found that my network performs well on the classic MNIST handwritten-digit dataset, but when applied to CIFAR-10 the generated images come out severely blurred. After trying several traditional GAN architectures, my conclusion is that a plain GAN is quite weak at generating the high-frequency information that makes up image detail. Here I summarize this early work; since there is relatively little example code online, I wrote my own.

Experiment Architecture

I experimented with a structure similar to a semantic-segmentation network. In short, the results were poor: the rough shape of each class is recognizable, but the details are far too blurry. The structure below should be treated only as a negative example for training a generator. That said, if you want to generate MNIST handwritten digits, or only need to train the discriminator, it works well enough.

Code

import numpy as np
from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class GAN():
    def __init__(self):  # global configuration
        self.img_shape = (32, 32, 3)
        self.save_path = r'C:\Users\Administrator\Desktop\photo\CGAN.h5'
        self.img_path = r'C:\Users\Administrator\Desktop\photo'
        self.batch_size = 100
        self.test_size = 200
        self.sample_interval = 1
        self.epoch = 200
        self.num_classes = 10
        self.train_mode = 0  # 0 = train from scratch, 1 = resume training. Resuming is not recommended: with this code, resumed training tends to collapse outright.

        # Build the GAN models
        if self.train_mode==0:
            self.generator_model = self.build_generator()
            self.discriminator_model = self.build_discriminator()
            self.model = self.build_model()
        else:
            self.model = keras.models.load_model(self.save_path)
            self.generator_model = keras.Model(inputs=self.model.layers[1].input, outputs=self.model.layers[1].output)
            self.discriminator_model = keras.Model(inputs=self.model.layers[2].input, outputs=self.model.layers[2].output)
    def build_generator(self):  # Generator

        input = keras.Input(shape=(int(self.img_shape[0]/16),int(self.img_shape[1]/16),self.num_classes))
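        # The conditional noise enters as a 2x2x(num_classes) tensor (32/16 = 2) and is
        # upsampled 16x back to 32x32. At each scale the main path is concatenated with a
        # transposed-convolution branch taken directly from the input, in the spirit of
        # the skip connections used in semantic-segmentation networks.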

        c00 = layers.UpSampling2D((2, 2))(input)

        x = layers.Conv2D(256,(3,3),padding='same',activation='relu')(c00)
        x = layers.Conv2D(256, (3, 3), padding='same',activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)


        c10 = layers.UpSampling2D((2, 2))(x)
        c11 = layers.Conv2DTranspose(256, (8,8), strides=4, padding='same', activation='relu')(input)
        x = layers.concatenate([c10,c11],axis=-1)
        x = layers.Conv2D(128, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(128, (3, 3), padding='same',activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)

        c20 = layers.UpSampling2D((2, 2))(x)
        c21 = layers.Conv2DTranspose(128, (16, 16), strides=8, padding='same', activation='relu')(input)
        x = layers.concatenate([c20,c21],axis=-1)
        x = layers.Conv2D(64, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(64, (3, 3), padding='same',activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)


        c30 = layers.UpSampling2D((2, 2))(x)
        c31 = layers.Conv2DTranspose(64, (32, 32), strides=16, padding='same', activation='relu')(input)
        x = layers.concatenate([c30,c31],axis=-1)
        x = layers.Conv2D(32, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(32, (3, 3), padding='same',activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)


        output = layers.Conv2D(self.img_shape[2], (1, 1), padding='same', activation='sigmoid')(x)

        model = keras.Model(inputs=input, outputs=output, name='generator')
        model.summary()
        return model

    def build_discriminator(self):  # Discriminator

        input = keras.Input(shape=self.img_shape)

        x = layers.Conv2D(32, (3, 3), padding='same',activation='relu')(input)
        x = layers.Conv2D(32, (3, 3), padding='same',activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)

        x = layers.Conv2D(64, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(64, (3, 3), padding='same',activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)

        x = layers.Conv2D(128, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(128, (3, 3), padding='same',activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)

        x = layers.Conv2D(256, (3, 3), padding='same',activation='relu')(x)
        x = layers.Conv2D(256, (3, 3), padding='same',activation='relu')(x)


        x = layers.Flatten()(x)
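        # (num_classes + 1)-way softmax: indices 0-9 are the real CIFAR-10 categories and
        # the extra index (num_classes = 10) marks generated "fake" images, so a single
        # discriminator handles classification and real/fake detection at once.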
        output = layers.Dense(self.num_classes+1, activation='softmax')(x)

        model = keras.Model(inputs=input, outputs=output, name='discriminator')
        model.summary()
        return model

    def build_model(self):  # Combined GAN model: generator followed by discriminator
        inputs = keras.Input(shape=(int(self.img_shape[0]/16),int(self.img_shape[1]/16),self.num_classes))
        img = self.generator_model(inputs)
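        # When this combined model is trained with the discriminator frozen,
        # only the generator weights are updated.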
        outputs = self.discriminator_model(img)
        model = keras.Model(inputs=inputs, outputs=outputs)
        return model

    def load_data(self):
        (train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
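        # Scale pixels to [0, 1] to match the generator's sigmoid output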
        train_images,test_images = (train_images /255),(test_images/255)

        if (self.img_shape[2]==1):  # add a channel dimension for single-channel datasets (unused for CIFAR-10, which already has 3 channels)
            train_images = np.expand_dims(train_images,  axis=-1)
            test_images = np.expand_dims(test_images, axis=-1)

        train_idx = np.arange(0,train_images.shape[0],1)
        print('img train number:', train_images.shape)
        print('img test number: ', test_images.shape)
        return train_images,train_labels,train_idx,test_images,test_labels

    def noise_generate(self,batch_size,labels):  # Conditional noise generator
        noise = np.random.normal(0, 1, (batch_size, int(self.img_shape[0]/16), int(self.img_shape[1]/16), self.num_classes))  # standard Gaussian noise
        n_labels = keras.utils.to_categorical(labels, num_classes=self.num_classes)
        n_labels = n_labels.reshape(batch_size, 1, 1, self.num_classes)
        noise = noise * n_labels  # after multiplying by the one-hot labels, only the channel of the given class keeps noise
        return noise
    def compile(self):  # Compile the models

        self.discriminator_model.compile(loss='categorical_crossentropy',
                                         optimizer=keras.optimizers.Adam(0.0001, 0.00001),
                                         metrics=['categorical_accuracy'])
        self.discriminator_model.trainable = False  # freeze the discriminator so that training the combined model only updates the generator
        if self.train_mode ==1:
            self.model=self.build_model()

            print('continue train')
        self.model.summary()
        self.model.compile(optimizer=keras.optimizers.Adam(0.0001, 0.00001), loss='categorical_crossentropy', )

    def train(self):
        self.compile()
        train_images,train_labels,train_idx,test_images,test_labels = self.load_data()  # load the data
        fake = np.ones((self.batch_size))*(self.num_classes)  # "fake" label = class index num_classes (10)
        fake = keras.utils.to_categorical(fake,num_classes=self.num_classes+1)
        step = int(train_images.shape[0] / self.batch_size)  # steps per epoch
        print('step:', step)

        for epoch in range(self.epoch):
            train_idx = (tf.random.shuffle(train_idx)).numpy()  # reshuffle the indices once per epoch
            print('val_acc', self.pred(mode=1, test_images=test_images, test_labels=test_labels))
            if epoch % self.sample_interval == 0:  # save sample images
                self.generate_sample_images(epoch)

            for i in range(step):
                idx = train_idx[i * self.batch_size:i * self.batch_size + self.batch_size]  # indices for this batch
                imgs = train_images[idx]  # images for these indices
                labels = train_labels[idx]

                #---------------------------------------------------------------generate standard Gaussian noise
                noise = self.noise_generate(self.batch_size,labels)
                gan_imgs = self.generator_model.predict(noise)  # generate images from the noise

                # ---------------------------------------------------------------
                labels=keras.utils.to_categorical(labels,num_classes=self.num_classes+1)  # one-hot real labels over num_classes+1 classes
                total_imgs=tf.concat((gan_imgs,imgs),axis=0)  # fake images first, then real images
                total_labels=tf.concat((fake,labels),axis=0)  # labels in the same order

                # ----------------------------------------------train the discriminator

                discriminator_loss = self.discriminator_model.train_on_batch(total_imgs, total_labels)

                # ----------------------------------------------- train the generator (through the frozen discriminator)
                generator_loss = self.model.train_on_batch(noise, labels)

                print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                    epoch, i, discriminator_loss[0], 100 * discriminator_loss[1], generator_loss))

            # print('val_acc', self.pred(mode=1, test_images=test_images, test_labels=test_labels))
            self.model.save(self.save_path)  # save the model at the end of every epoch
            print('save model')
    def generate_sample_images(self, epoch):

        row, col = 2, 2  # rows and columns of the sample grid
        labels = np.random.randint(0,self.num_classes,(row * col))

        noise = self.noise_generate(row * col,labels)  # conditional noise
        gan_imgs = self.generator_model.predict(noise)
        fig, axs = plt.subplots(row, col)  # create the figure
        idx = 0

        for i in range(row):
            for j in range(col):
                axs[i, j].imshow(gan_imgs[idx, :, :])  # add cmap='gray' for single-channel images
                axs[i, j].axis('off')
                idx += 1
        fig.savefig(self.img_path + "/%d.png" % epoch)
        plt.close()  # close the figure

    def pred(self,mode,test_images=None,test_labels=None):  # how to use the trained networks
        if (mode==0):  # mode 0: load the saved model and generate one sample per class
            model = keras.models.load_model(self.save_path)
            print('loading model')
            generator = keras.Model(inputs=model.layers[1].input, outputs=model.layers[1].output)
            discriminator = keras.Model(inputs=model.layers[2].input, outputs=model.layers[2].output)
            generator.summary()
            discriminator.summary()
            for i in range(10):  # test the generator on every class label
                label = i
                noise = self.noise_generate(1, label)
                img = np.squeeze(generator.predict([noise]))
                plt.imshow(img)
                plt.show()
        elif(mode==1):  # mode 1: validation - zero the "fake" column and measure the discriminator's accuracy
            print('testing')
            step=int(test_images.shape[0]/self.test_size)
            val_acc=0
            for i in range(step):
                pred=self.discriminator_model(test_images[i*self.test_size:(i+1)*self.test_size])
                pred=pred.numpy()
                pred[:,(self.num_classes)]=0
                pred=tf.argmax(pred,axis=-1)
                pred=keras.utils.to_categorical(pred,num_classes=self.num_classes+1)
                labels=keras.utils.to_categorical(test_labels[i*self.test_size:(i+1)*self.test_size],num_classes=self.num_classes+1)
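                # note: this is 1 minus the mean absolute difference of the one-hot vectors,
                # not standard top-1 accuracy, so it reads higher than the usual classification accuracy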
                acc=1-tf.reduce_mean(tf.abs(pred-labels))
                val_acc+=acc
            val_acc=val_acc/step
            return val_acc.numpy()
        else:
            pass

if __name__ == '__main__':
    gan = GAN()
    gan.train()
    gan.pred(mode=0)  # mode 0: generate samples with the trained generator

Results

Although the generator was a failure, the discriminator turned out well: its classification accuracy on the CIFAR-10 validation set reached 96%. That does not beat the top figure of 97%, but I think it already surpasses many conventional algorithms. The generated images are, frankly, too embarrassing to show.
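For reference, a more conventional way to check the discriminator's classification accuracy is plain top-1 accuracy with the "fake" class masked out, instead of the one-hot difference used in pred(mode=1) above. The snippet below is only a minimal sketch: it assumes a trained GAN instance named gan, and the helper name top1_accuracy is mine, not part of the original script.

import numpy as np
from tensorflow import keras

def top1_accuracy(gan):
    # Standard top-1 accuracy of the discriminator on the CIFAR-10 test set,
    # ignoring the extra "fake" class (index gan.num_classes).
    (_, _), (test_images, test_labels) = keras.datasets.cifar10.load_data()
    test_images = test_images / 255
    probs = gan.discriminator_model.predict(test_images, batch_size=gan.test_size)
    probs[:, gan.num_classes] = 0          # mask out the fake class
    pred = np.argmax(probs, axis=-1)
    return np.mean(pred == test_labels.flatten())

# usage (hypothetical): print('top-1 accuracy:', top1_accuracy(gan))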
