Reconstructing the Fashion_mnist Dataset with an Autoencoder

Autoencoder

from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, layers
import numpy as np
from matplotlib import pyplot as plt

Load the dataset

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# labels are not needed for reconstruction, so only the images go into the dataset
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(buffer_size=512).batch(512)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(512)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
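As a quick sanity check, one batch can be drawn from train_db to confirm that the pipeline yields unlabeled image tensors of shape (512, 28, 28) (a minimal illustrative snippet; the variable name sample_batch is arbitrary):

# Pull a single batch from the training pipeline and inspect its shape
sample_batch = next(iter(train_db))
print(sample_batch.shape)  # expected: (512, 28, 28); the last batch of an epoch may be smaller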

Build the network

class AutoEncoder(keras.Model):

    def __init__(self):
        super(AutoEncoder, self).__init__()

        # Encoder: [b, 784] => [b, 20]
        self.encoder = Sequential([
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(20)
        ])

        # Decoder: [b, 20] => [b, 784]
        self.decoder = Sequential([
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(784)
        ])

    # Forward pass
    def call(self, inputs, training=None):
        # [b, 784] => [b, 20]
        h = self.encoder(inputs)
        # [b, 20] => [b, 784]
        x_hat = self.decoder(h)

        return x_hat
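Before training, the shapes can be verified with a dummy batch: the encoder compresses [b, 784] down to a 20-dimensional code and the decoder maps it back to [b, 784]. A minimal illustrative check (the throwaway instance tmp_model is not used later):

# Shape check with a temporary model instance and random inputs
tmp_model = AutoEncoder()
dummy = tf.random.normal([4, 28 * 28])   # a fake batch of 4 flattened images
code = tmp_model.encoder(dummy)          # latent code
recon = tmp_model(dummy)                 # reconstruction logits
print(code.shape, recon.shape)           # (4, 20) (4, 784)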

Train the network

def save_images(imgs, name):
    # Tile 100 images of size 28x28 into a single 280x280 grayscale image
    new_im = Image.new('L', (280, 280))

    index = 0
    for i in range(0, 280, 28):
        for j in range(0, 280, 28):
            # Images are pasted column by column: the first 10 images fill
            # the leftmost column from top to bottom, and so on
            im = imgs[index]
            im = Image.fromarray(im, mode='L')
            new_im.paste(im, (i, j))
            index += 1

    new_im.save(name)

model = AutoEncoder()
model.build(input_shape=(None, 28 * 28))
model.summary()

optimizer = tf.optimizers.Adam(learning_rate=1e-3)
Model: "auto_encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential (Sequential)      multiple                  236436    
_________________________________________________________________
sequential_1 (Sequential)    multiple                  237200    
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
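The parameter counts in the summary can be verified by hand: every Dense layer contributes in_features * units weights plus units biases. A short check of that arithmetic:

# Encoder: 784 -> 256 -> 128 -> 20
enc_params = 784 * 256 + 256 + 256 * 128 + 128 + 128 * 20 + 20   # 236,436
# Decoder: 20 -> 128 -> 256 -> 784
dec_params = 20 * 128 + 128 + 128 * 256 + 256 + 256 * 784 + 784  # 237,200
print(enc_params, dec_params, enc_params + dec_params)           # 236436 237200 473636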

Start training

for epoch in range(20):

    for step, x in enumerate(train_db):

        # [b, 28, 28] => [b, 784]
        x = tf.reshape(x, [-1, 28 * 28])
        # Record the forward pass on the gradient tape
        with tf.GradientTape() as tape:
            # Forward pass
            x_rec_logits = model(x)
            # Reconstruction loss: binary cross-entropy on the logits
            rec_loss = tf.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)
            rec_loss = tf.reduce_mean(rec_loss)
        # Compute gradients
        grads = tape.gradient(rec_loss, model.trainable_variables)
        # Update the network parameters
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Print the loss of the last training batch in this epoch
    print("epoch: ", epoch, "loss: ", float(rec_loss))

    # Sample a batch of images from the test set
    x = next(iter(test_db))
    logits = model(tf.reshape(x, [-1, 784]))
    # Convert the logits to pixel values with a sigmoid
    x_hat = tf.sigmoid(logits)
    # [b, 784] => [b, 28, 28]: restore the original image shape
    x_hat = tf.reshape(x_hat, [-1, 28, 28])

    # [b, 28, 28] => [2b, 28, 28]
    # Concatenate the first 50 inputs with the first 50 reconstructions
    x_concat = tf.concat([x[:50], x_hat[:50]], axis=0)
    # Scale back to the 0-255 range
    x_concat = x_concat.numpy() * 255.
    # Convert to uint8
    x_concat = x_concat.astype(np.uint8)
    save_images(x_concat, 'ae_images/mnist_%d.png' % epoch)

epoch:  0 loss:  0.1876431256532669
epoch:  1 loss:  0.14163847267627716
epoch:  2 loss:  0.12352141737937927
epoch:  3 loss:  0.11942803859710693
epoch:  4 loss:  0.11525192111730576
epoch:  5 loss:  0.10021436214447021
epoch:  6 loss:  0.10526927560567856
epoch:  7 loss:  0.10288294404745102
epoch:  8 loss:  0.10139968246221542
epoch:  9 loss:  0.10215207189321518
epoch:  10 loss:  0.0961870551109314
epoch:  11 loss:  0.091026671230793
epoch:  12 loss:  0.09655070304870605
epoch:  13 loss:  0.09417414665222168
epoch:  14 loss:  0.08978977054357529
epoch:  15 loss:  0.08931374549865723
epoch:  16 loss:  0.08951258659362793
epoch:  17 loss:  0.08937102556228638
epoch:  18 loss:  0.09456444531679153
epoch:  19 loss:  0.08556753396987915
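Note that the loss printed each epoch is only the reconstruction loss of the last training batch. For a rough overall measure, the same loss can be averaged over the whole test set (an illustrative sketch, not part of the original training script):

# Average reconstruction loss over the test set
test_losses = []
for x in test_db:
    x = tf.reshape(x, [-1, 28 * 28])
    logits = model(x)
    loss = tf.losses.binary_crossentropy(x, logits, from_logits=True)
    test_losses.append(tf.reduce_mean(loss))
print("test reconstruction loss:", float(tf.reduce_mean(test_losses)))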
def printImage(images):
    # Show the images in a 5-column grid with matplotlib
    plt.figure(figsize=(10, 10))
    for i in range(20):
        plt.subplot(5, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i], cmap=plt.cm.binary)

x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
# Convert the logits to pixel values with a sigmoid
x_hat = tf.sigmoid(logits)
# [b, 784] => [b, 28, 28]: restore the original image shape
x_hat = tf.reshape(x_hat, [-1, 28, 28])

# [b, 28, 28] => [2b, 28, 28]
# Concatenate the first 10 inputs with the first 10 reconstructions
x_concat = tf.concat([x[:10], x_hat[:10]], axis=0)
# Scale back to the 0-255 range
x_concat = x_concat.numpy() * 255.
# Convert to uint8
x_concat = x_concat.astype(np.uint8)
printImage(x_concat)
  • The first two rows of the grid are the original test images; the two rows below them are the reconstructions.
    [Figure: original vs. reconstructed test images]

Images saved to disk:

Epoch 1
The left 5 columns are the original images and the right 5 columns are the reconstructions; at this stage they are still quite blurry.
[Figure: reconstructions after epoch 1]

Epoch 10
[Figure: reconstructions after epoch 10]
Epoch 20
[Figure: reconstructions after epoch 20]
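Since the encoder compresses every image into a 20-dimensional vector, the latent codes themselves can also be inspected directly through model.encoder (a minimal sketch using the model trained above):

# Encode one test batch into 20-dimensional latent vectors
x = next(iter(test_db))
z = model.encoder(tf.reshape(x, [-1, 784]))
print(z.shape)  # (512, 20): one 20-dimensional code per test image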
