自編碼器
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras
from tensorflow.keras import Sequential, layers
加載數據集
# Load MNIST (28x28 grayscale digits). The autoencoder trains without
# labels; y_train / y_test are only used for the shape printout below.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Convert uint8 pixels to float32 in [0, 1] so they can be used directly
# as binary cross-entropy targets.
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# Image-only pipelines: shuffled batches of 512 for training,
# plain batches of 512 for evaluation.
train_db = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(buffer_size=512)
    .batch(512)
)
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(512)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
構建網絡
class AutoEncoder(keras.Model):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps a [b, input_dim] batch through Dense layers of
    256 -> 128 -> latent_dim units; the decoder mirrors it back through
    128 -> 256 -> input_dim units. The decoder's final layer has no
    activation, so ``call`` returns logits: apply ``tf.sigmoid`` to get
    pixel values, or use a binary cross-entropy with ``from_logits=True``.

    Args:
        latent_dim: size of the bottleneck code (default 20).
        input_dim: length of a flattened input image (default 28 * 28).
    """

    def __init__(self, latent_dim=20, input_dim=28 * 28):
        super().__init__()
        # Encoder: [b, input_dim] => [b, latent_dim]
        self.encoder = Sequential([
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(latent_dim),
        ])
        # Decoder: [b, latent_dim] => [b, input_dim]; no activation, i.e. logits.
        self.decoder = Sequential([
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(input_dim),
        ])

    def call(self, inputs, training=None):
        """Encode then decode ``inputs``; return the reconstruction logits.

        ``training`` is accepted for Keras API compatibility; no layer here
        behaves differently between training and inference.
        """
        # [b, input_dim] => [b, latent_dim]
        h = self.encoder(inputs)
        # [b, latent_dim] => [b, input_dim]
        x_hat = self.decoder(h)
        return x_hat
網絡訓練
def save_images(imgs, name):
    """Tile up to 100 28x28 grayscale images into a 10x10 grid and save it.

    Images are placed column by column: imgs[0..9] fill the leftmost column
    from top to bottom, imgs[10..19] the next column, and so on. If fewer
    than 100 images are given, the remaining cells stay black; extra images
    beyond the first 100 are ignored.

    Args:
        imgs: sequence of 28x28 arrays expected to hold uint8 pixel values
            (0-255); each is converted to uint8 before pasting.
        name: output file path; Pillow picks the format from the extension.
    """
    grid = Image.new('L', (280, 280))
    for index, img in enumerate(imgs[:100]):
        # paste() takes (left, upper). index // 10 selects the column and
        # index % 10 the row, so consecutive images run down a column.
        left = (index // 10) * 28
        upper = (index % 10) * 28
        # A 2-D uint8 array maps to mode 'L' automatically; the explicit
        # `mode` argument of fromarray is deprecated in recent Pillow.
        tile = Image.fromarray(np.asarray(img, dtype=np.uint8))
        grid.paste(tile, (left, upper))
    grid.save(name)
model = AutoEncoder()
# Inputs are batches of flattened 28x28 images: [None, 784].
model.build(input_shape=(None, 28 * 28))
model.summary()
# Use `learning_rate`: the legacy `lr` alias is rejected by newer Keras
# optimizers.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
Model: "auto_encoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
開始訓練
# save_images() writes into this directory, and saving fails if it is missing.
os.makedirs('ae_images', exist_ok=True)

for epoch in range(20):
    # Average the loss over the whole epoch. The final batch holds only
    # 60000 % 512 = 96 images, so its loss alone is a noisy progress signal.
    epoch_loss, num_batches = 0.0, 0
    for x in train_db:
        # [b, 28, 28] => [b, 784]
        x = tf.reshape(x, [-1, 28 * 28])
        # Record the forward pass so gradients can be computed.
        with tf.GradientTape() as tape:
            # The model returns logits (the decoder has no output activation).
            x_rec_logits = model(x)
            # Per-pixel binary cross-entropy against the [0, 1] inputs,
            # averaged over pixels and then over the batch.
            rec_loss = keras.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)
            rec_loss = tf.reduce_mean(rec_loss)
        # Backpropagate and update the weights.
        grads = tape.gradient(rec_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        epoch_loss += float(rec_loss)
        num_batches += 1
    # Report the mean of the per-batch losses for this epoch.
    print("epoch: ", epoch, "loss: ", epoch_loss / num_batches)

    # Reconstruct the first test batch to visualize progress.
    x_sample = next(iter(test_db))
    logits = model(tf.reshape(x_sample, [-1, 784]))
    # Map logits to pixel intensities in (0, 1).
    x_hat = tf.sigmoid(logits)
    # [b, 784] => [b, 28, 28] to restore the image layout.
    x_hat = tf.reshape(x_hat, [-1, 28, 28])
    # First 50 inputs followed by their 50 reconstructions: [100, 28, 28].
    x_concat = tf.concat([x_sample[:50], x_hat[:50]], axis=0)
    # Scale back to 0-255 and convert to 8-bit integers for PIL.
    x_concat = (x_concat.numpy() * 255.).astype(np.uint8)
    save_images(x_concat, f'ae_images/mnist_{epoch}.png')
epoch: 0 loss: 0.1876431256532669
epoch: 1 loss: 0.14163847267627716
epoch: 2 loss: 0.12352141737937927
epoch: 3 loss: 0.11942803859710693
epoch: 4 loss: 0.11525192111730576
epoch: 5 loss: 0.10021436214447021
epoch: 6 loss: 0.10526927560567856
epoch: 7 loss: 0.10288294404745102
epoch: 8 loss: 0.10139968246221542
epoch: 9 loss: 0.10215207189321518
epoch: 10 loss: 0.0961870551109314
epoch: 11 loss: 0.091026671230793
epoch: 12 loss: 0.09655070304870605
epoch: 13 loss: 0.09417414665222168
epoch: 14 loss: 0.08978977054357529
epoch: 15 loss: 0.08931374549865723
epoch: 16 loss: 0.08951258659362793
epoch: 17 loss: 0.08937102556228638
epoch: 18 loss: 0.09456444531679153
epoch: 19 loss: 0.08556753396987915
def printImage(images, max_images=20, grid=(5, 5)):
    """Draw up to ``max_images`` grayscale images on a matplotlib subplot grid.

    Subplots are filled row by row with ticks and grid lines hidden. The
    defaults reproduce the original layout: 20 images on a 5x5 grid, which
    leaves the last row empty. Fewer images than ``max_images`` are drawn
    without error.

    Args:
        images: indexable sequence of 2-D image arrays.
        max_images: maximum number of images to draw (default 20).
        grid: (rows, cols) of the subplot grid (default (5, 5)).
    """
    rows, cols = grid
    count = min(max_images, len(images), rows * cols)
    plt.figure(figsize=(10, 10))
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        # The 'binary' colormap renders low values light and high values dark.
        plt.imshow(images[i], cmap=plt.cm.binary)
# Reconstruct the first test batch and display 10 inputs next to their
# reconstructions.
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
# The model outputs logits; sigmoid maps them to pixel intensities in (0, 1).
x_hat = tf.sigmoid(logits)
# [b, 784] => [b, 28, 28] to restore the image layout.
x_hat = tf.reshape(x_hat, [-1, 28, 28])
# First 10 inputs followed by their 10 reconstructions: [20, 28, 28].
x_concat = tf.concat([x[:10], x_hat[:10]], axis=0)
# Scale back to the 0-255 range and convert to 8-bit integers.
x_concat = (x_concat.numpy() * 255.).astype(np.uint8)
printImage(x_concat)
- 前兩行(共 10 張)是原始圖片,第三、四行(共 10 張)是對應的重建圖片;5×5 網格的最後一行留空。
保存到本地的圖片(每個 epoch 一張,路徑為 `ae_images/mnist_<epoch>.png`):
第 1 個 epoch(mnist_0.png)
左邊 5 列是原始圖片,右邊 5 列是重建後的圖片。可以看到,此時的重建結果還不太清晰。
第 10 個 epoch(mnist_9.png)
第 20 個 epoch(mnist_19.png)