Deep Learning | Converging on a Better Deployment Model via Distillation

Convergence via Distillation

Knowledge Distillation with Keras - Classification and Regression

How to converge on a better deployment model through distillation

Introduction to Knowledge Distillation

Knowledge distillation is a model-compression procedure in which a small (student) model is trained to match a large, pre-trained (teacher) model. Knowledge is transferred from the teacher to the student by minimizing a loss function aimed at matching both the softened teacher logits and the ground-truth labels. The logits are softened by applying a "temperature" scaling function in the softmax, which effectively smooths the probability distribution and reveals the inter-class relationships learned by the teacher.

Hinton et al. (2015)  
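To make the effect of the temperature concrete, here is a small illustrative sketch (not part of the original example, using made-up logits): the same logits are pushed through a temperature-scaled softmax, and the larger the temperature, the softer (flatter) the resulting distribution.

import tensorflow as tf

# Hypothetical logits for a single sample (e.g. raw teacher outputs).
logits = tf.constant([[8.0, 2.0, 0.5]])

for temperature in [1.0, 3.0, 10.0]:
    # Temperature-scaled softmax: softmax(logits / T).
    soft = tf.nn.softmax(logits / temperature, axis=1)
    print(temperature, soft.numpy().round(3))

# T=1 keeps the distribution peaked (roughly [0.997, 0.002, 0.001]);
# larger T flattens it, exposing the relative similarities between classes.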

Import the required packages

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

Construct the Distiller class

The custom Distiller() class overrides the Model methods train_step, test_step, and compile(). In order to use the distiller, we need:

A trained teacher model
A student model to train
A student loss function on the difference between student predictions and ground truth
A distillation loss function, along with a temperature, on the difference between the soft student predictions and the soft teacher labels
An alpha factor to weight the student loss against the distillation loss
An optimizer for the student and (optional) metrics to evaluate performance

In the train_step method, we perform a forward pass of both the teacher and the student, compute the loss by weighting student_loss and distillation_loss by alpha and 1 - alpha respectively, and perform the backward pass. Note: only the student weights are updated, and therefore we only compute the gradients for the student weights.

In the test_step method, we evaluate the student model on the provided dataset.

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student
 
    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """ Configure the distiller.
        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
 
    def train_step(self, data):
        # Unpack data 解析數據
        x, y = data
 
        # Forward pass of teacher 前向傳遞 
        teacher_predictions = self.teacher(x, training=False)
 
        with tf.GradientTape() as tape:
            # Forward pass of student 前向傳遞 
            student_predictions = self.student(x, training=True)
 
            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
 
            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )
            # Total loss: alpha*hard loss + (1-alpha)*soft loss
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
 
        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
 
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
 
        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)
 
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results
 
    def test_step(self, data):
        # Unpack the data
        x, y = data
 
        # Compute predictions
        y_prediction = self.student(x, training=False)
 
        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)
 
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
 
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

Create the student and teacher models

First, we create a teacher model and a smaller student model. Both models are convolutional neural networks built with Sequential(), though they could be any Keras model.

# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)
 
# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)
 
# Clone student for later comparison
student_scratch = keras.models.clone_model(student)

Prepare the dataset

The dataset used for training the teacher and for distillation is MNIST, and the procedure would be equivalent for any other dataset, e.g. CIFAR-10, given a suitable choice of models. Both the student and the teacher are trained on the training set and evaluated on the test set.

# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
 
# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
 
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

Train the teacher model

In knowledge distillation we assume that the teacher is trained and fixed. Thus, we start by training the teacher model on the training set in the usual way.

# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
 
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)

 

Epoch 1/5
1875/1875 [==============================] - 162s 86ms/step - loss: 0.1438 - sparse_categorical_accuracy: 0.9553
Epoch 2/5
1875/1875 [==============================] - 172s 92ms/step - loss: 0.0905 - sparse_categorical_accuracy: 0.9732
Epoch 3/5
1875/1875 [==============================] - 172s 92ms/step - loss: 0.0798 - sparse_categorical_accuracy: 0.9768
Epoch 4/5
1875/1875 [==============================] - 171s 91ms/step - loss: 0.0767 - sparse_categorical_accuracy: 0.9785
Epoch 5/5
1875/1875 [==============================] - 179s 95ms/step - loss: 0.0699 - sparse_categorical_accuracy: 0.9808
313/313 [==============================] - 6s 20ms/step - loss: 0.0894 - sparse_categorical_accuracy: 0.9763
[0.08935610204935074, 0.9763000011444092]

Distill the teacher to the student

With the teacher model trained, we only need to initialize a Distiller(student, teacher) instance, compile() it with the desired losses, hyperparameters, and optimizer, and distill the teacher to the student. We will also train the student from scratch later for comparison.

# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # replace this loss function accordingly for regression tasks (see the regression sketch below)
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)
 
# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)
 
# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)

The training output is as follows:

Epoch 1/3
1875/1875 [==============================] - 37s 19ms/step - sparse_categorical_accuracy: 0.8863 - student_loss: 0.5352 - distillation_loss: 8.6172
Epoch 2/3
1875/1875 [==============================] - 37s 20ms/step - sparse_categorical_accuracy: 0.9647 - student_loss: 0.1374 - distillation_loss: 1.8981
Epoch 3/3
1875/1875 [==============================] - 38s 20ms/step - sparse_categorical_accuracy: 0.9718 - student_loss: 0.1047 - distillation_loss: 1.2105
313/313 [==============================] - 1s 2ms/step - sparse_categorical_accuracy: 0.9732 - student_loss: 0.1035

[0.9732000231742859, 0.0381324402987957]
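Although the title also mentions regression, the Distiller above softens logits with a temperature softmax, which only makes sense for classification. For a regression target, a minimal sketch (my own adaptation, not part of the original example) is to drop the temperature scaling and let the student match the teacher's raw outputs directly, e.g. with mean squared error; test_step needs no change because it only uses student_loss_fn.

class RegressionDistiller(Distiller):
    # Sketch of a Distiller variant for regression: no temperature softmax,
    # the student matches the teacher's raw predictions directly.
    def train_step(self, data):
        x, y = data
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            # Hard loss against the ground-truth targets.
            student_loss = self.student_loss_fn(y, student_predictions)
            # Soft loss against the teacher's raw outputs (no softening).
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, student_predictions
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, student_predictions)

        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

# Example compile() for regression ("student_reg"/"teacher_reg" are hypothetical
# models with a linear output layer):
# distiller_reg = RegressionDistiller(student=student_reg, teacher=teacher_reg)
# distiller_reg.compile(
#     optimizer=keras.optimizers.Adam(),
#     metrics=[keras.metrics.MeanAbsoluteError()],
#     student_loss_fn=keras.losses.MeanSquaredError(),
#     distillation_loss_fn=keras.losses.MeanSquaredError(),
#     alpha=0.1,
# )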

Train the student from scratch for comparison

# Train student model from scratch for comparison
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)

#student(train from scratch) accuracy: 0.9778
#0.9896 VS. 0.9778

 

Epoch 1/3
1875/1875 [==============================] - 7s 4ms/step - loss: 0.0680 - sparse_categorical_accuracy: 0.9791
Epoch 2/3
1875/1875 [==============================] - 7s 4ms/step - loss: 0.0597 - sparse_categorical_accuracy: 0.9819
Epoch 3/3
1875/1875 [==============================] - 7s 4ms/step - loss: 0.0545 - sparse_categorical_accuracy: 0.9829
313/313 [==============================] - 1s 2ms/step - loss: 0.0640 - sparse_categorical_accuracy: 0.9797
[0.06404071301221848, 0.9797000288963318]

If the teacher is trained for 5 epochs and the student is distilled on this teacher for 3 epochs, then in this example you should experience a performance boost compared to training the same student model from scratch, and even compared to the teacher itself.

You should expect the teacher to have an accuracy of around 97.6%, the student trained from scratch around 97.6%, and the distilled student around 98.1%.
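Since the goal is a smaller model for deployment, only the student needs to be kept once distillation is done; the Distiller wrapper and the teacher can be discarded. A minimal sketch (the file name is just an example):

# Keep only the distilled student for deployment; the teacher is no longer needed.
distiller.student.save("distilled_student.h5")

# Later, load and serve the small model on its own (it outputs logits,
# so apply a softmax if class probabilities are needed).
deployed = keras.models.load_model("distilled_student.h5")
probs = tf.nn.softmax(deployed.predict(x_test[:1]), axis=1)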
