GAN生成對抗網絡合集(六):GAN-cls –具有匹配感知的判別器(附代碼)

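1 GAN-cls原理

GAN-cls(matching-aware discriminator,具有匹配感知的判別器)是在條件GAN的基礎上,讓判別器除了分辨圖像的真假之外,還要判斷圖像與標籤是否匹配。判別器的輸入共有三種組合:(真實圖像, 正確標籤)應判爲真;(生成圖像, 正確標籤)應判爲假;(真實圖像, 錯誤標籤)也應判爲假。生成器的目標則是讓(生成圖像, 正確標籤)被判別器判爲真。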





2 代碼

直接修改上一篇 GAN生成對抗網絡合集(五):LSGan-最小二乘GAN(附代碼) 代碼,將其改成GAN-cls。

  1. 修改判別器D
def discriminator(x, y):
    """Matching-aware discriminator: score (image, label) pairs.

    Args:
        x: batch of images; reshaped to [-1, 28, 28, 1] inside the function.
        y: batch of label vectors (the callers in this file pass one-hot
            labels); projected to ``n_input`` values and reshaped into an
            image-sized plane.

    Returns:
        A 1-D tensor of sigmoid scores, one per (image, label) pair.
    """
    # Reuse the variables when the 'discriminator' scope has already been built.
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    with tf.variable_scope('discriminator', reuse=reuse):
        # Map the label vector to the same number of values as an image ...
        y = slim.fully_connected(y, num_outputs=n_input, activation_fn=leaky_relu)
        # ... and lay it out as a single-channel 28x28 plane.
        y = tf.reshape(y, shape=[-1, 28, 28, 1])

        x = tf.reshape(x, shape=[-1, 28, 28, 1])

        # Stack image and label plane along the channel axis so they are
        # processed jointly: x.shape = [-1, 28, 28, 2]
        x = tf.concat(axis=3, values=[x, y])

        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        x = slim.flatten(x)
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu)

        # Single sigmoid unit: the real/fake (and matching) score in (0, 1).
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        disc = tf.squeeze(disc, -1)
    return disc
  2. 添加錯誤標籤輸入符,構建網絡結構


#  3. Define the model: hyper-parameters, inputs, and the tensors flowing through G and D
batch_size = 10  # mini-batch size; also fixes the number of rows of z_rand below
classes_dim = 10  # 10 classes
con_dim = 2  # dimension of the latent-code variable z_con (its definition below is commented out)
rand_dim = 38  # dimension of the plain noise z_rand, drawn with tf.random_normal
n_input = 784  # 28 * 28

x = tf.placeholder(tf.float32, [None, n_input])  # real input images
y = tf.placeholder(tf.int32, [None])  # true labels
misy = tf.placeholder(tf.int32, [None])  # mismatched ("wrong") labels

# z_con = tf.random_normal((batch_size, con_dim))  # 2 columns
z_rand = tf.random_normal((batch_size, rand_dim))  # 38 columns
# Generator input = one-hot label + noise. Because z_rand has exactly batch_size
# rows, every feed of y must contain batch_size labels.
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand])  # 50 columns, shape = (10, 50)
gen = generator(z)  # shape = (10, 28, 28, 1)
genout = tf.squeeze(gen, -1)  # shape = (10, 28, 28)

# labels for discriminator (unused by the least-squares losses in this script)
# y_real = tf.ones(batch_size)  # real
# y_fake = tf.zeros(batch_size)  # fake

# Discriminator D: one batched call over three (image, label) groups, in this order:
#   real image + true label, generated image + true label, real image + wrong label
xin = tf.concat([x, tf.reshape(gen, shape=[-1, 784]), x], 0)
yin = tf.concat([tf.one_hot(y, depth=classes_dim), tf.one_hot(y, depth=classes_dim), tf.one_hot(misy, depth=classes_dim)], 0)
# disc_real, class_real, _ = discriminator(x)
# disc_fake, class_fake, con_fake = discriminator(gen)
# pred_class = tf.argmax(class_fake, dimension=1)

disc_all = discriminator(xin, yin)
# Split the scores back into the three groups (same order as xin / yin above)
disc_real, disc_fake, disc_mis = tf.split(disc_all, 3)
  3. 修改loss值
# Discriminator loss (least-squares form).
# The cross-entropy variants below are kept only as commented-out reference.
# loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real))  # 1
# loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake))  # 0

# Push (real image, true label) scores towards 1, and both (generated image, true label)
# and (real image, wrong label) scores towards 0; the two "0" terms are averaged.
loss_d = tf.reduce_sum(tf.square(disc_real-1) + (tf.square(disc_fake-0)+tf.square(disc_mis-0))/2) / 2

# Generator loss: push the score of (generated image, true label) towards 1
loss_g = tf.reduce_sum(tf.square(disc_fake-1)) / 2

# Categorical factor loss (disabled; class_fake / class_real are not defined in this script)
# loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))
# loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))
# loss_c = (loss_cf + loss_cr) / 2

# Continuous factor loss for the latent-code variable (disabled; con_fake / z_con are not defined)
# loss_con = tf.reduce_mean(tf.square(con_fake - z_con))
  4. 使用MonitoredTrainingSession創建session,開始訓練
#  5. Training
#  Open a session and, inside the loop, run the two optimizers one after the other.

training_epochs = 3
display_step = 1

# MonitoredTrainingSession restores from / saves to checkpoint_dir automatically.
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpointsnew', save_checkpoint_secs=60) as sess:
    total_batch = int(mnist.train.num_examples / batch_size)
    # Resume from the epoch implied by the restored global step.
    start_epoch = int(global_step.eval(session=sess) / total_batch)
    print("global_step.eval(session=sess)", global_step.eval(session=sess), start_epoch)
    for epoch in range(start_epoch, training_epochs):
        # Iterate over the whole training set
        for _ in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Labels of an independent batch serve as the mismatched labels
            _, mis_batch_ys = mnist.train.next_batch(batch_size)
            feeds = {x: batch_xs, y: batch_ys, misy: mis_batch_ys}

            # One discriminator update followed by one generator update
            l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step], feeds)
            l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step], feeds)

        # Report the losses of the last batch of the epoch
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc), l_gen)




# !/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = '黎明'

#  1.引入頭文件並加載mnist數據
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow.contrib.slim as slim
import time
from timer import Timer
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from a local directory; one-hot encoding is left disabled
# (see the commented-out argument), and the graph applies tf.one_hot itself.
mnist = input_data.read_data_sets("/media/S318080208/py_pictures/minist/")  # ,one_hot=True)

tf.reset_default_graph()  # clear the default graph stack and reset the global default graph

#  2.定義生成器與判別器
def generator(x):
    """Generate images from a latent vector.

    Two fully connected layers followed by two transposed convolutions;
    every layer except the output is followed by batch norm + ReLU.

    Args:
        x: 2-D batch of latent vectors (label one-hot concatenated with
            noise in this file).

    Returns:
        Image tensor with sigmoid activations; with slim's default padding
        the two stride-2 transposed convolutions upsample 7x7 to 28x28x1.
    """
    # Reuse the variables when the 'generator' scope has already been built.
    reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0
    with tf.variable_scope('generator', reuse=reuse):
        x = slim.fully_connected(x, 1024)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        x = slim.fully_connected(x, 7 * 7 * 128)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        # Lay the features out as a 7x7 map with 128 channels
        x = tf.reshape(x, [-1, 7, 7, 128])
        x = slim.conv2d_transpose(x, 64, kernel_size=[4, 4], stride=2, activation_fn=None)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        # Sigmoid keeps the output pixels in (0, 1)
        z = slim.conv2d_transpose(x, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid)
    return z

def leaky_relu(x):
    """Leaky ReLU: keep positive entries, scale the others by 0.01."""
    is_positive = tf.greater(x, 0)
    leaked = 0.01 * x
    return tf.where(is_positive, x, leaked)

def discriminator(x, y):
    """Matching-aware discriminator: score (image, label) pairs.

    Args:
        x: batch of images; reshaped to [-1, 28, 28, 1] inside the function.
        y: batch of label vectors (the callers in this file pass one-hot
            labels); projected to ``n_input`` values and reshaped into an
            image-sized plane.

    Returns:
        A 1-D tensor of sigmoid scores, one per (image, label) pair.
    """
    # Reuse the variables when the 'discriminator' scope has already been built.
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    with tf.variable_scope('discriminator', reuse=reuse):
        # Map the label vector to the same number of values as an image ...
        y = slim.fully_connected(y, num_outputs=n_input, activation_fn=leaky_relu)
        # ... and lay it out as a single-channel 28x28 plane.
        y = tf.reshape(y, shape=[-1, 28, 28, 1])

        x = tf.reshape(x, shape=[-1, 28, 28, 1])

        # Stack image and label plane along the channel axis so they are
        # processed jointly: x.shape = [-1, 28, 28, 2]
        x = tf.concat(axis=3, values=[x, y])

        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        x = slim.flatten(x)
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu)

        # Single sigmoid unit: the real/fake (and matching) score in (0, 1).
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        disc = tf.squeeze(disc, -1)
    return disc

#  3. Define the model: hyper-parameters, inputs, and the tensors flowing through G and D
batch_size = 10  # mini-batch size; also fixes the number of rows of z_rand below
classes_dim = 10  # 10 classes
con_dim = 2  # dimension of the latent-code variable z_con (its definition below is commented out)
rand_dim = 38  # dimension of the plain noise z_rand, drawn with tf.random_normal
n_input = 784  # 28 * 28

x = tf.placeholder(tf.float32, [None, n_input])  # real input images
y = tf.placeholder(tf.int32, [None])  # true labels
misy = tf.placeholder(tf.int32, [None])  # mismatched ("wrong") labels

# z_con = tf.random_normal((batch_size, con_dim))  # 2 columns
z_rand = tf.random_normal((batch_size, rand_dim))  # 38 columns
# Generator input = one-hot label + noise. Because z_rand has exactly batch_size
# rows, every feed of y must contain batch_size labels.
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand])  # 50 columns, shape = (10, 50)
gen = generator(z)  # shape = (10, 28, 28, 1)
genout = tf.squeeze(gen, -1)  # shape = (10, 28, 28)

# labels for discriminator (unused by the least-squares losses in this script)
# y_real = tf.ones(batch_size)  # real
# y_fake = tf.zeros(batch_size)  # fake

# Discriminator D: one batched call over three (image, label) groups, in this order:
#   real image + true label, generated image + true label, real image + wrong label
xin = tf.concat([x, tf.reshape(gen, shape=[-1, 784]), x], 0)
yin = tf.concat([tf.one_hot(y, depth=classes_dim), tf.one_hot(y, depth=classes_dim), tf.one_hot(misy, depth=classes_dim)], 0)
# disc_real, class_real, _ = discriminator(x)
# disc_fake, class_fake, con_fake = discriminator(gen)
# pred_class = tf.argmax(class_fake, dimension=1)

disc_all = discriminator(xin, yin)
# Split the scores back into the three groups (same order as xin / yin above)
disc_real, disc_fake, disc_mis = tf.split(disc_all, 3)

#  4. Define the loss functions and optimizers
# Discriminator loss (least-squares form).
# The cross-entropy variants below are kept only as commented-out reference.
# loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real))  # 1
# loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake))  # 0

# Push (real image, true label) scores towards 1, and both (generated image, true label)
# and (real image, wrong label) scores towards 0; the two "0" terms are averaged.
loss_d = tf.reduce_sum(tf.square(disc_real-1) + (tf.square(disc_fake-0)+tf.square(disc_mis-0))/2) / 2

# Generator loss: push the score of (generated image, true label) towards 1
loss_g = tf.reduce_sum(tf.square(disc_fake-1)) / 2

# Categorical factor loss (disabled; class_fake / class_real are not defined in this script)
# loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))
# loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))
# loss_c = (loss_cf + loss_cr) / 2

# Continuous factor loss for the latent-code variable (disabled; con_fake / z_con are not defined)
# loss_con = tf.reduce_mean(tf.square(con_fake - z_con))

# Collect each network's trainable variables, selected by variable-scope name
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]

# Step counters: the generator gets its own counter, while the discriminator
# drives the graph's global step.
gen_global_step = tf.Variable(0, trainable=False)

# MonitoredTrainingSession requires a global step tensor to exist in the graph.
global_step = tf.train.get_or_create_global_step()

# Each optimizer only updates its own network's variables.
train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d, var_list=d_vars,
                                                     global_step=global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g, var_list=g_vars,
                                                   global_step=gen_global_step)

#  5. Training and testing
#  Open a session and, inside the loop, run the two optimizers one after the other.

training_epochs = 3
display_step = 1

# MonitoredTrainingSession restores from / saves to checkpoint_dir automatically.
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpointsnew', save_checkpoint_secs=60) as sess:
    total_batch = int(mnist.train.num_examples / batch_size)
    # Resume from the epoch implied by the restored global step.
    start_epoch = int(global_step.eval(session=sess) / total_batch)
    print("global_step.eval(session=sess)", global_step.eval(session=sess), start_epoch)
    for epoch in range(start_epoch, training_epochs):
        # Iterate over the whole training set
        for _ in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Labels of an independent batch serve as the mismatched labels
            _, mis_batch_ys = mnist.train.next_batch(batch_size)
            feeds = {x: batch_xs, y: batch_ys, misy: mis_batch_ys}

            # One discriminator update followed by one generator update
            l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step], feeds)
            l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step], feeds)

        # Report the losses of the last batch of the epoch
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc), l_gen)

    # Test: evaluate both losses on the first batch_size test samples.
    # The feed must hold exactly batch_size rows because z_rand has a fixed batch size.
    _, mis_batch_ys = mnist.train.next_batch(batch_size)
    test_feeds = {x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size], misy: mis_batch_ys}
    print("Result:", loss_d.eval(test_feeds, session=sess),
          loss_g.eval(test_feeds, session=sess))

    # Generate images conditioned on the labels of the test samples
    show_num = 10
    gensimple, inputx, inputy = sess.run(
        [genout, x, y], feed_dict={x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]})

    # Top row: real test images; bottom row: images generated for the same labels
    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(show_num):
        a[0][i].imshow(np.reshape(inputx[i], (28, 28)))
        a[1][i].imshow(np.reshape(gensimple[i], (28, 28)))

    # Without this call the figure is never displayed when run as a script
    plt.show()





使用GAN-cls技術同樣也實現了生成與標籤對應的樣本,而且整體代碼的運算要比ACGAN簡潔很多(絲毫沒覺得,專門算過時間,沒啥變化 =.=)。

