GANs Collection (6): GAN-cls – a matching-aware discriminator (with code)

1 The GAN-cls principle

       This is a GAN enhancement technique: the matching-aware discriminator. As discussed earlier, InfoGAN used the ACGAN approach to guide the correspondence (classification) between generated samples and their class labels. GAN-cls achieves the same effect in a simpler way, by strengthening the discriminator so that it judges not only whether an image is real or fake, but also whether the image-label pairing is genuine.

(Personal take) Nothing changes substantively: training time is not reduced, and the technique is not really simplified, if anything it becomes more complex. It is mainly a shift in thinking. In ACGAN, D receives either fake samples with the correct class label or real samples with the correct class label. In GAN-cls, D instead receives three combinations: real sample with real label, fake sample with real label, and real sample with a fake label (a random label with no corresponding image).

       Concretely, GAN-cls keeps the original GAN and changes the discriminator's input to the concatenation of an image with its label, so that the discriminator's input features contain both the image's features and the label's features. This discriminator then judges three pairings separately: real label with real image, fake label with real image, and real label with fake image. The expected outputs are true, false, and false respectively, and training simply pushes the network to converge in that direction. The generator needs no modification at all. This one simple step is enough to generate simulated data matched to its label.
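As a minimal sketch of the idea (the names d, real_x, fake_x, real_y and mis_y are illustrative, not taken from the code below):

# Minimal sketch of the three matching-aware pairings and their targets.
# d(image, label) is any discriminator that scores an image-label pair.
score_real = d(real_x, real_y)   # real image + matching label,   target 1
score_fake = d(fake_x, real_y)   # generated image + real label,  target 0
score_mis  = d(real_x, mis_y)    # real image + mismatched label, target 0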


2 Code

We directly modify the code from the previous post, GANs Collection (5): LSGAN – Least Squares GAN (with code), turning it into GAN-cls.

  1. Modify the discriminator D
    Change the discriminator's inputs to x and y, where the newly added y is the class label paired with the sample (either its true label or a mismatched one). Inside the discriminator, a fully connected layer first maps y to the same number of dimensions as the image; y is then reshaped to the image's shape and joined to x with concat so the two are processed together. The rest of the pipeline is unchanged: two convolutions followed by two fully connected layers, with the last layer producing disc. The code is as follows:
# def discriminator(x, num_classes=10, num_cont=2):
def discriminator(x, y):  # discriminator: two convolutions on x, then two fully connected layers; y is the sample's class label
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    # print (reuse)
    # print (x.get_shape())
    with tf.variable_scope('discriminator', reuse=reuse):

        y = slim.fully_connected(y, num_outputs=n_input, activation_fn=leaky_relu)  # map y to the same number of dimensions as the image
        y = tf.reshape(y, shape=[-1, 28, 28, 1])    # reshape y into the image format

        x = tf.reshape(x, shape=[-1, 28, 28, 1])

        # concatenate the two so they are processed together
        x = tf.concat(axis=3, values=[x, y])  # x.shape = [-1, 28, 28, 2]

        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        # print ("conv2d",x.get_shape())
        x = slim.flatten(x)  # flatten the input
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu)
        # recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn=leaky_relu)

        # the shared features can feed different output heads producing different results
        # a 1-dimensional output head produces the real/fake decision (1 or 0)
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        disc = tf.squeeze(disc, -1)
        # print ("disc",disc.get_shape()) # 0 or 1

        # a 10-dimensional output head would produce the classification result (sample label)
        # recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)

        # a 2-dimensional output head would reconstruct the latent continuous information
        # recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)
    return disc  # recog_cat, recog_cont
  2. Add the wrong-label placeholder and build the network
    Add the wrong-label placeholder misy, and build the three discriminator inputs: real sample with real label, generated image gen with real label, and real sample with wrong label. The latent-information part z_con is removed.

Note: the x and y tensors of the three input combinations are each concatenated along the batch dimension, forming a single discriminator input. After the discriminator runs, the split function cuts its output back into three results, disc_real, disc_fake and disc_mis: the scores for real sample with real label, generated image gen with real label, and real sample with wrong label, respectively. Writing it this way keeps the code a bit more compact; you could just as well feed each x, y pair separately and call the discriminator three times, with the same effect (see the sketch after the code below).

##################################################################
#  3. Define the network model: parameters, inputs/outputs, and the
#     intermediate inputs/outputs flowing through G and D
##################################################################
batch_size = 10  # batch size used for each training step
classes_dim = 10  # 10 classes
con_dim = 2  # dimension of the latent information variable z_con (removed in GAN-cls)
rand_dim = 38  # dimension of the ordinary noise z_rand; standard-Gaussian random numbers
n_input = 784  # 28 * 28

x = tf.placeholder(tf.float32, [None, n_input])  # x: real input images
y = tf.placeholder(tf.int32, [None])  # y: real labels
misy = tf.placeholder(tf.int32, [None])  # wrong (mismatched) labels

# z_con = tf.random_normal((batch_size, con_dim))  # 2 columns
z_rand = tf.random_normal((batch_size, rand_dim))  # 38 columns
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand])  # 48 columns, shape = (10, 48)
gen = generator(z)  # shape = (10, 28, 28, 1)
genout = tf.squeeze(gen, -1)  # shape = (10, 28, 28)

# labels for discriminator
# y_real = tf.ones(batch_size)  # real
# y_fake = tf.zeros(batch_size)  # fake

# discriminator D
xin = tf.concat([x, tf.reshape(gen, shape=[-1, 784]), x], 0)
yin = tf.concat([tf.one_hot(y, depth=classes_dim), tf.one_hot(y, depth=classes_dim), tf.one_hot(misy, depth=classes_dim)], 0)
# disc_real, class_real, _ = discriminator(x)
# disc_fake, class_fake, con_fake = discriminator(gen)
# pred_class = tf.argmax(class_fake, dimension=1)

disc_all = discriminator(xin, yin)
# scores for: real sample + real label, generated image gen + real label, real sample + wrong label
disc_real, disc_fake, disc_mis = tf.split(disc_all, 3)
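For reference, here is a minimal sketch of the three-call alternative mentioned above (the disc_real2/disc_fake2/disc_mis2 names are illustrative; the reuse check inside discriminator() makes the repeated calls share one set of weights):

# Alternative sketch: call the discriminator once per pairing instead of splitting.
y_onehot = tf.one_hot(y, depth=classes_dim)
misy_onehot = tf.one_hot(misy, depth=classes_dim)
disc_real2 = discriminator(x, y_onehot)                           # real image + real label
disc_fake2 = discriminator(tf.reshape(gen, [-1, 784]), y_onehot)  # generated image + real label
disc_mis2 = discriminator(x, misy_onehot)                         # real image + wrong label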
  3. Modify the loss values
    The discriminator loss again uses the LSGAN form, but the 'fake' part becomes the sum of the disc_fake and disc_mis terms divided by 2, because the discriminator should reject both the generator's samples and inputs carrying a wrong label.
# discriminator loss
# loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real))  # 1
# loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake))  # 0

loss_d = tf.reduce_sum(tf.square(disc_real-1) + (tf.square(disc_fake-0)+tf.square(disc_mis-0))/2) / 2

# generator loss
loss_g = tf.reduce_sum(tf.square(disc_fake-1)) / 2

# categorical factor loss (removed in GAN-cls)
# loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))
# loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))
# loss_c = (loss_cf + loss_cr) / 2

# continuous factor loss for the latent variables (removed in GAN-cls)
# loss_con = tf.reduce_mean(tf.square(con_fake - z_con))
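Written out (with D(·,·) the matching-aware score, G the generator, y the true label and y' a mismatched one), the two objectives being minimized are:

L_D = 1/2 * Σ [ (D(x, y) - 1)^2 + ( D(G(z), y)^2 + D(x, y')^2 ) / 2 ]
L_G = 1/2 * Σ (D(G(z), y) - 1)^2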
  4. Create the session with MonitoredTrainingSession and start training
    Define global_step and create the session with tf.train.MonitoredTrainingSession, which manages the checkpoint files. Inside the session, build the wrong-label data and train the model.
##################################################################
#  5. Training and testing
#     Build the session and run the two optimizers in a loop
##################################################################

training_epochs = 3
display_step = 1

with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpointsnew', save_checkpoint_secs=60) as sess:
    total_batch = int(mnist.train.num_examples / batch_size)
    print("global_step.eval(session=sess)", global_step.eval(session=sess),
          int(global_step.eval(session=sess) / total_batch))
    for epoch in range(int(global_step.eval(session=sess) / total_batch), training_epochs):
        avg_cost = 0.

        # iterate over the whole dataset
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)  # fetch a batch
            _, mis_batch_ys = mnist.train.next_batch(batch_size)  # fetch another batch; its labels serve as wrong labels
            feeds = {x: batch_xs, y: batch_ys, misy: mis_batch_ys}

            # Fit training using batch data
            l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step], feeds)
            l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step], feeds)

        # display training progress
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc), l_gen)

    print("Done!")

-----------------------------------------------------------------------------------------------------------------------------------------

The complete code follows:

# !/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = '黎明'

##################################################################
#  1. Imports and MNIST data loading
##################################################################
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.contrib.slim as slim
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/media/S318080208/py_pictures/minist/")  # ,one_hot=True)

tf.reset_default_graph()  # clear the default graph stack and reset the global default graph


##################################################################
#  2. Define the generator and the discriminator
##################################################################
def generator(x):  # generator: two fully connected layers + two transposed convolutions, each followed by batch normalization (BN)
    reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0  # reuse only if the scope already holds variables
    # print (x.get_shape())
    with tf.variable_scope('generator', reuse=reuse):
        x = slim.fully_connected(x, 1024)
        # print(x)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        x = slim.fully_connected(x, 7 * 7 * 128)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        x = tf.reshape(x, [-1, 7, 7, 128])
        # print ('22', tf.tensor.get_shape())
        x = slim.conv2d_transpose(x, 64, kernel_size=[4, 4], stride=2, activation_fn=None)
        # print ('gen',x.get_shape())
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        z = slim.conv2d_transpose(x, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid)
        # print ('genz',z.get_shape())
    return z


def leaky_relu(x):
    return tf.where(tf.greater(x, 0), x, 0.01 * x)


# def discriminator(x, num_classes=10, num_cont=2):
def discriminator(x, y):  # discriminator: two convolutions on x, then two fully connected layers; y is the sample's class label
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    # print (reuse)
    # print (x.get_shape())
    with tf.variable_scope('discriminator', reuse=reuse):

        y = slim.fully_connected(y, num_outputs=n_input, activation_fn=leaky_relu)  # map y to the same number of dimensions as the image
        y = tf.reshape(y, shape=[-1, 28, 28, 1])    # reshape y into the image format

        x = tf.reshape(x, shape=[-1, 28, 28, 1])

        # concatenate the two so they are processed together
        x = tf.concat(axis=3, values=[x, y])  # x.shape = [-1, 28, 28, 2]

        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        # print ("conv2d",x.get_shape())
        x = slim.flatten(x)  # flatten the input
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu)
        # recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn=leaky_relu)

        # the shared features can feed different output heads producing different results
        # a 1-dimensional output head produces the real/fake decision (1 or 0)
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        disc = tf.squeeze(disc, -1)
        # print ("disc",disc.get_shape()) # 0 or 1

        # a 10-dimensional output head would produce the classification result (sample label)
        # recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)

        # a 2-dimensional output head would reconstruct the latent continuous information
        # recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)
    return disc  # recog_cat, recog_cont


##################################################################
#  3. Define the network model: parameters, inputs/outputs, and the
#     intermediate inputs/outputs flowing through G and D
##################################################################
batch_size = 10  # batch size used for each training step
classes_dim = 10  # 10 classes
con_dim = 2  # dimension of the latent information variable z_con (removed in GAN-cls)
rand_dim = 38  # dimension of the ordinary noise z_rand; standard-Gaussian random numbers
n_input = 784  # 28 * 28

x = tf.placeholder(tf.float32, [None, n_input])  # x: real input images
y = tf.placeholder(tf.int32, [None])  # y: real labels
misy = tf.placeholder(tf.int32, [None])  # wrong (mismatched) labels

# z_con = tf.random_normal((batch_size, con_dim))  # 2 columns
z_rand = tf.random_normal((batch_size, rand_dim))  # 38 columns
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand])  # 48 columns, shape = (10, 48)
gen = generator(z)  # shape = (10, 28, 28, 1)
genout = tf.squeeze(gen, -1)  # shape = (10, 28, 28)

# labels for discriminator
# y_real = tf.ones(batch_size)  # real
# y_fake = tf.zeros(batch_size)  # fake

# discriminator D
xin = tf.concat([x, tf.reshape(gen, shape=[-1, 784]), x], 0)
yin = tf.concat([tf.one_hot(y, depth=classes_dim), tf.one_hot(y, depth=classes_dim), tf.one_hot(misy, depth=classes_dim)], 0)
# disc_real, class_real, _ = discriminator(x)
# disc_fake, class_fake, con_fake = discriminator(gen)
# pred_class = tf.argmax(class_fake, dimension=1)

disc_all = discriminator(xin, yin)
# scores for: real sample + real label, generated image gen + real label, real sample + wrong label
disc_real, disc_fake, disc_mis = tf.split(disc_all, 3)

##################################################################
#  4. Define the loss functions and the optimizers
##################################################################
# discriminator loss
# loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real))  # 1
# loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake))  # 0

loss_d = tf.reduce_sum(tf.square(disc_real-1) + (tf.square(disc_fake-0)+tf.square(disc_mis-0))/2) / 2

# generator loss
loss_g = tf.reduce_sum(tf.square(disc_fake-1)) / 2

# categorical factor loss (removed in GAN-cls)
# loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))
# loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))
# loss_c = (loss_cf + loss_cr) / 2

# continuous factor loss for the latent variables (removed in GAN-cls)
# loss_con = tf.reduce_mean(tf.square(con_fake - z_con))

# collect each sub-network's trainable variables
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]

# optimizers
# disc_global_step = tf.Variable(0, trainable=False)
gen_global_step = tf.Variable(0, trainable=False)

global_step = tf.train.get_or_create_global_step()  # required when using MonitoredTrainingSession

train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d, var_list=d_vars,
                                                     global_step=global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g, var_list=g_vars,
                                                   global_step=gen_global_step)

##################################################################
#  5. Training and testing
#     Build the session and run the two optimizers in a loop
##################################################################

training_epochs = 3
display_step = 1

with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpointsnew', save_checkpoint_secs=60) as sess:
    total_batch = int(mnist.train.num_examples / batch_size)
    print("global_step.eval(session=sess)", global_step.eval(session=sess),
          int(global_step.eval(session=sess) / total_batch))
    for epoch in range(int(global_step.eval(session=sess) / total_batch), training_epochs):
        avg_cost = 0.

        # iterate over the whole dataset
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)  # fetch a batch
            _, mis_batch_ys = mnist.train.next_batch(batch_size)  # fetch another batch; its labels serve as wrong labels
            feeds = {x: batch_xs, y: batch_ys, misy: mis_batch_ys}

            # Fit training using batch data
            l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step], feeds)
            l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step], feeds)

        # display training progress
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc), l_gen)

    print("Done!")

    # testing
    _, mis_batch_ys = mnist.train.next_batch(batch_size)
    print("result:",
          loss_d.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size], misy: mis_batch_ys},
                      session=sess)
          , loss_g.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size], misy: mis_batch_ys},
                        session=sess))

    # generate samples conditioned on the test images' labels
    show_num = 10
    gensimple, inputx, inputy = sess.run(
        [genout, x, y], feed_dict={x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]})

    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(show_num):
        a[0][i].imshow(np.reshape(inputx[i], (28, 28)))
        a[1][i].imshow(np.reshape(gensimple[i], (28, 28)))

    plt.draw()
    plt.show()

-----------------------------------------------------------------------------------------------------------------------------------------

Results:

[Figure: test images (top row) and the generated samples conditioned on their labels (bottom row)]
GAN-cls likewise succeeds in generating samples that match their labels, and the overall computation is supposedly much leaner than ACGAN's (I did not notice this at all: I specifically timed it and there was essentially no difference =.=).
