快速上手生成對抗生成網絡生成手寫數字集(直接上代碼以及詳細註釋,親測可用)

GAN的原理其實很簡單,就是生成網絡G, 和判別網絡D的對抗過程, 生成網絡努力使得生成的虛假物品更加真實,而判別網絡努力分別出哪些是G生成的,哪些是真實的,在這樣一個對抗的過程中兩個網絡的能力不斷得到提升。最終達到一個相對平衡的結果:理想狀態下, G學習到真實數據的分佈,而D無法分別出真實數據和生成數據, 即D(真) = D(假) = 0.5.
這裏插入幾張模型生成的圖片,從左到右分別是隨機生成的圖片,100輪之後的圖片,2000輪之後的圖片,8000輪之後的圖片。 代碼雖然有100多行,但註釋大概佔了一般左右。一起交流,一起進步!
未經過訓練的結果訓練100輪的結果訓練2000多輪的結果訓練4000多次之後得到的結果

import tensorflow as tf
from matplotlib import gridspec
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
import cv2
"""
定義各個超參數,包括輸入圖片的大小, 隱藏層大小, 學習率,batch_size, 迭代次數等
"""
# D 代表判別網絡, 定義判別網絡的參數
D_input_size = 784
D_H_layer1 = 200
D_output_size = 1

# G 代表生成網絡,定義生成網絡的參數
G_input_size = 100
G_H_layer1 = 300
G_output_size = 784

Learning_rate = 1e-3
iterations = 50000
batch_size = 16

"""
:param D_input_size: 判別網絡的第一層輸入(樣本圖片(生成圖片)的大小),需要輸入給判別網絡做判斷的圖片的大小
:param D_H_layer1: 判別網絡的第一個隱藏層的大小,暫時設置成默認值200
:param D_output: 判別網絡的第輸出層的大小, 因爲輸出是一個概率值所以只有一個神經元
        所以判別網絡的模型是 784--》200——》1
"""
D_W1 = tf.Variable(tf.truncated_normal([D_input_size, D_H_layer1], stddev=0.1), name="D_W1",dtype=tf.float32)
D_b1 = tf.Variable(tf.zeros([D_H_layer1]), name="D_b1", dtype=tf.float32)
D_W2 = tf.Variable(tf.truncated_normal([D_H_layer1, D_output_size], stddev=0.1), name="D_W2",dtype=tf.float32)
D_b2 = tf.Variable(tf.zeros([D_output_size]), name="D_b2", dtype=tf.float32)


"""
:param G_input_size:生成網絡的輸入,默認設置成100,是一些隨機噪聲,從這些噪聲中 逐步的生成理想的圖片
:param G_H_layer1: 生成網絡的第一層隱藏層的大小,默認設置成300
:param G_output_size: 生成網絡的輸出層大小,即要生成的圖片大小,顯然應該是與生成網絡的輸入一致
所以生成網絡的模型是100——》300——》784
"""
G_W1 = tf.Variable(tf.truncated_normal([G_input_size, G_H_layer1], stddev=0.1), name="G_W1",dtype=tf.float32)
G_b1 = tf.Variable(tf.zeros([G_H_layer1]), name="G_b1", dtype=tf.float32)
G_W2 = tf.Variable(tf.truncated_normal([G_H_layer1, G_output_size], stddev=0.1), name="G_W2",dtype=tf.float32)
G_b2 = tf.Variable(tf.zeros([G_output_size]), name="G_b2", dtype=tf.float32)

# 用列表保存,後面構建網絡,和更新參數的時候會用到
G_variables = [G_W1, G_b1, G_W2, G_b2]
D_variables = [D_W1, D_b1, D_W2, D_b2]

# 判別網絡
def discriminator(D_input):
    D_A1 = tf.nn.relu(tf.matmul(D_input, D_W1) + D_b1)
    D_output = tf.nn.sigmoid(tf.matmul(D_A1, D_W2) + D_b2)
    return D_output

# 生成網絡
def generator(G_input):
    G_A1 = tf.nn.relu(tf.matmul(G_input, G_W1) + G_b1)
    G_output = tf.sigmoid(tf.matmul(G_A1, G_W2) + G_b2)
    return G_output

# 使用placeholder來定義生成網絡的輸入, G_image 表示生成的圖片, real_image 表示真實的圖片
G_input = tf.placeholder(tf.float32, shape=[None, 100])
G_image = generator(G_input)
real_image = tf.placeholder(tf.float32, shape=[None, 784])
# 判別網絡對生成圖片的判別概率
D_fake = discriminator(G_image)
# 判別網絡對真實圖片的判別概率
D_real = discriminator(real_image)


# 至此就到了最重要的一步了! 定義生成網絡和判別網絡的損失函數
# 首先是判別網絡的損失函數:判別網絡的目的是能很好的區分真實圖片和生成圖片,
# D_real 是判別網絡對真實圖片的判別概率,當然是越大越好,D_fake是判別網絡對
# 生成圖片的判別概率,判別器越好則越能是D_fake 變小,即(1.0 - D_fake)越大
# 越好,由此可得我們的目的是最小化目標函數D_loss
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# 其次來看生成網絡的損失函數,生成的目的是儘可能的騙過判別網絡,使得生成的圖片和
# 樣本中的片面儘量相同,故對於生成網絡來說D_fake 越大越能體現其優越。即G_loss越小
# 越好
G_loss = -tf.reduce_mean(tf.log(D_fake))

# 使用反向傳播算法來更新參數,注意var_list表示需要更新的參數列表
D_train = tf.train.AdamOptimizer(Learning_rate).minimize(D_loss, var_list=D_variables)
G_train = tf.train.AdamOptimizer(Learning_rate).minimize(G_loss, var_list=G_variables)
# 讀取mnist數據
mnist = input_data.read_data_sets("MNIST_DATA", one_hot=True)

# 畫圖函數
def plot(generate_pictures):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(generate_pictures):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    return fig

# 打開一個會話

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 構建生成圖片的路徑, 若路徑不存在則新建一個
    if not os.path.exists('output_pictures/'):
        os.makedirs('output_pictures/')

    for i in range(iterations):
        # 從mnist 中讀取真實的圖片
        real_picture, _ = mnist.train.next_batch(batch_size)
        # 運行判別網絡和生成網絡, 注意要feed兩個placeholder
        current_generate_loss, _ = sess.run([G_loss, G_train], feed_dict={
            G_input: np.random.uniform(-1., 1., size=[batch_size, 100])})
        current_discriminator_loss, _ = sess.run([D_loss, D_train],
                                                 feed_dict={G_input: np.random.uniform(-1., 1.,
                                                                                       size=[batch_size,
                                                                                             100]),
                                                            real_image: real_picture})
        # 每隔100輪保存一下生成器生成的圖片
        if i % 100 == 0:
            # np.random.uniform()生成大小爲【self.batch_size, 100】的均勻分佈,作爲生成網絡的輸入
            generate_pictures = sess.run(G_image,
                                         feed_dict={G_input: np.random.uniform(-0.5, 0.5,
                                                                               size=[batch_size, 100])})
            # generate_pictures爲16 x 784 的矩陣,每一行表示一張圖片,從中隨機抽取一張保存下來
            import matplotlib.pyplot as plt

            fig = plot(generate_pictures)
            plt.savefig('output_pictures/image{}.jpg'.format(str(i // 100)), bbox_inches='tight')
            plt.close(fig)
            # 顯示單張圖
            single_picture = generate_pictures[0]
            # 這裏因爲生成圖片的最後一層採用的是signoid函數,輸出值爲0-1,而像素值是0-255,所以乘以255
            single_picture = np.reshape(single_picture, (28, 28)) * 255
            cv2.imwrite("output_pictures/A{}.jpg".format(str(i // 100)), single_picture)


            # 輸入連個網絡當前的損失
            print("Iterations: " + str(i) + ",the D_loss is %.4f,  and the G_loss is %.4f" % (
                current_discriminator_loss, current_generate_loss))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章