【調參06】如何通過設置及時停止避免過擬合



訓練神經網絡時,控制訓練的週期很關鍵。如果過早停止訓練,可能導致欠擬合;如果訓練週期過長,可能導致過擬合,從而導致泛化能力很差。一種解決方案是在訓練數據集上進行訓練,在驗證數據集的性能開始下降時停止訓練。這種簡單,有效且廣泛使用的訓練神經網絡的方法稱爲及時停止(Early Stopping)。

1. 如何使用及時停止

及時停止要求網絡配置爲處於受限狀態,這意味着網絡具有比問題所需容量更多的容量。

在訓練網絡時,使用比通常更多的訓練週期,以使網絡有足夠的週期擬合,然後設置及時停止使在訓練週期合適時停止訓練。使用及時停止有三個要素:

  • 要監視的指標;
  • 觸發停止條件;
  • 要保存的模型;

1.1 要監視的指標

訓練神經網絡時,通常會從數據集中拆分出一個子集(例如30%)作爲驗證數據集,用於在訓練過程中監視模型的性能。該驗證集不用於訓練模型。通常也可以使用驗證數據集上的損失作爲監視指標。

一般來說,在迴歸問題中,使用驗證集上的預測誤差作爲指標;在分類問題中,使用驗證集上的準確率作爲指標。

訓練數據集上模型的損失也將作爲訓練過程的一部分提供,還可以在訓練數據集上計算和監視其他指標。

在每個時期結束時在驗證集上評估模型的性能,這會增加訓練期間的額外計算成本。可以通過不那麼頻繁地評估模型(例如每2、5或10個訓練時期)來減少這種情況。


1.2 觸發停止條件

選擇了模型的監視指標之後,就要設置停止訓練的觸發器。

觸發器將使用監視的性能指標來決定何時停止訓練。這通常是模型在驗證集上的性能,例如驗證損失(val_loss)。

在最簡單的情況下,與先前訓練時期(例如,val_loss增加)相比,驗證數據集的性能下降後,訓練就會立即停止。

在實踐中可能需要更詳細的觸發器。這是因爲神經網絡的訓練是隨機的,並且可能夾雜很多噪聲。繪製驗證損失和驗證準確率曲線可以看出,模型的性能可能會多次上升和下降。也就是說,出現第一個過度擬合跡象時就停止訓練是不妥當的,因爲實際的驗證集上的誤差曲線存在多個局部極小值。

一些更詳細的觸發器可能包括:

  • 給定週期內,指標沒有變化;
  • 給定週期內,指標的絕對變化;
  • 給定週期內,指標的性能下降;
  • 給定週期內,指標的平均變化;

1.3 要保存的模型

在停止訓練時,該模型的泛化誤差比先前時期的模型大一些。

因此,需要考慮合適保存模型或者說如何保存性能最好的模型,亦即保存訓練過程中哪個模型的權重。這取決於爲停止訓練過程而選擇的觸發器。例如,如果觸發是性能從一個時期到下一個時期的降低,那麼將優先考慮模型在先前時期的權重。如果要求觸發器在固定時期內觀察到性能下降,則將首選觸發器週期開始時的模型。

一個常用的方法是:每當驗證集上的誤差改善時,就保存此時模型權重文件。當終止訓練時,得到的模型權重就是最佳模型的權重。


2. 及時停止使用技巧

2.1 適用範圍

幾乎所有的神經網絡都需要設置及時停止。


2.2 通過繪製曲線觀察

在使用及時停止之前,可能需要設置較長的訓練週期來讓模型進行擬合,並在訓練和驗證數據集上監視模型的性能。實時或長期運行結束後繪製模型的性能,通過觀察監視指標的變化情況,有助於選擇提前停止的觸發器。


2.3 監視指標選擇

損失是在訓練過程中進行監控並觸發提前停止的簡單指標。

問題在於,損失並不能總反映出最適合業務場景需求的模型。最好選擇一種性能指標進行監視,以最好地定義模型的性能。


2.4 訓練週期選擇

及時停止的一個問題是模型沒有利用所有可用的訓練數據。

這可能需要避免過擬合併在所有可能的數據上進行訓練,尤其是在訓練數據量非常有限的情況下。

推薦的方法是將訓練時期的數量視爲超參數,使用k-折交叉驗證對不同值的範圍進行網格搜索。可以固定訓練週期的數量,並在所有可用數據上擬合最終模型。

及時停止過程可以重複多次。可以記錄停止訓練的週期數。然後,在將最終模型擬合到所有可用的訓練數據上時,可以使用及時停止的所有週期數的平均值。

每次運行早期停止時,都可以使用不同的訓練集劃分爲訓練和驗證步驟來執行此過程。一種替代方法可能是使用驗證數據集及時停止,然後通過對所提供的驗證集進行進一步訓練來更新最終模型。


2.5 過擬合驗證

多次重複及時停止過程可能會導致模型過度擬合驗證數據集,這和過度擬合訓練數據集一樣容易。一種方法是在選擇了模型的所有其他超參數後才使用及時停止。

另一種策略可能是每次使用及時停止時,都將訓練數據集分爲不同的訓練集和驗證集。


3. TensorFlow API

3.1 拆分驗證集

在 tensorflow.keras 中,有兩種方式可以設置驗證集:

...
model.fit(train_X, train_y, validation_data=(val_x, val_y))
...
model.fit(train_X, train_y, validation_split=0.3)

3.2 設置監視指標

如果在 model.fit() API中設置了 validation_datavalidation_split 參數,則會返回驗證數據集上的損失,名稱爲 val_loss

可以在編譯模型時通過 model.compile 函數的 metrics 參數來指定它們。此參數使用Python列表傳入,例如 mse 表示均方誤差,precision 表示精度。常用監視指標:

...
model.compile(..., metrics=['accuracy'])

如果在訓練時監視其它指標,也可通過相同的名稱提供給 metrics 參數,如 val_accuracy 表示驗證集上的準確率。mse 表示訓練集上的均方誤差,val_mse 表示驗證集上的均方誤差。


3.3 EarlyStopping API

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
    baseline=None, restore_best_weights=False
)

參數說明:

  • monitor:監視指標。
  • min_delta:監視指標的最小變化(有改進的定義);即絕對變化小於 min_delta,不視爲改進。
  • patience:沒有改進的時期數,之後訓練將停止。
  • verbose:在控制檯打印信息的級別;verbose設置爲1時,訓練停止時返回訓練週期數。
  • mode{"auto", "min", "max"} 中選其一。在 min 模式下,當監視的數量停止減少時,訓練將停止;在max 模式下,當監視的數量停止增加時,它將停止;在 auto 模式下,將根據監視數量的名稱自動推斷出變化方向;
  • baseline:監視數量的基準值。如果模型沒有超過基線的改善,則停止訓練。
  • restore_best_weights:是否使用該訓練週期中監視指標最佳模型所對應的權重。如果爲 False,則使用在訓練的最後一步獲得的模型權重。

3.4 ModelCheckPoint API

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', **kwargs
)

用於以一定間隔保存模型或權重(在檢查點文件中),之後可以重新加載模型或權重以從保存的狀態繼續訓練。

  • filepath:字符串,保存模型文件的路徑。filepath可以包含命名的格式選項,這會傳遞給該API的方法 on_epoch_end 中。例如:如果filepath爲 weights.{epoch:02d}-{val_loss:.2f}.hdf5,則模型將以epoch和驗證損失作爲文件名保存。
  • monitor:監視指標。
  • verbose:在控制檯打印信息的級別,0或1。
  • save_best_only:如果爲 save_best_only=True,則最佳模型不會被覆蓋。如果filepath不包含格式設置選項,原名稱會被覆蓋。
  • save_weights_only:如果爲True,則僅保存模型的權重(model.save_weights(filepath)),否則保存完整的模型(model.save(filepath))。
  • save_freq'epoch'或整數。使用'epoch'時,回調函數會在每個時期後保存模型。使用整數時,回調將在許多批次結束時保存模型。請注意,如果保存未與時間段保持一致,則受監視的指標可能會不太可靠(它可能只反映1個批次,因爲每個epoch都會重置該指標)。默認爲’epoch’。

3.5 實例

# 定義模型
model = Sequential()
model.add(Dense(500, input_dim=2, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 編譯模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# 定義回調
callbacks_set = [EarlyStopping(monitor='val_loss', mode='min', verbose=1),
				ModelCheckpoint('best_model.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)]
# 訓練模型
# 及時停止,同時保存驗證集上性能最好(驗證損失最低)的模型
history = model.fit(trainX, trainy, validation_data=(testX, testy), epochs=4000, verbose=0, 
					callbacks=callbacks_set)

tensorboard log 等等其它回調同理。


https://machinelearningmastery.com/early-stopping-to-avoid-overtraining-neural-network-models/
https://machinelearningmastery.com/how-to-stop-training-deep-neural-networks-at-the-right-time-using-early-stopping/
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping?hl=en
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint?hl=en

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章