一個值得深思的問題?爲什麼驗證集的loss會小於訓練集的loss

在本教程中,您將學習在訓練自己的自定義深度神經網絡時,驗證損失可能低於訓練損失的三個主要原因。

我的驗證損失低於訓練損失!

怎麼可能呢?

  • 我是否意外地將訓練和驗證loss繪圖的標籤切換了? 潛在地。 我沒有像matplotlib這樣的繪圖庫,因此將丟失日誌通過管道傳輸到CSV文件,然後在Excel中進行繪圖。 絕對容易發生人爲錯誤。
  • 我的代碼中有錯誤嗎? 幾乎可以確定。 我同時在自學Java和機器學習-該代碼中肯定存在某種錯誤。
  • 我只是因爲太疲倦而無法理解我的大腦嗎? 也很有可能。 我一生中的睡眠時間不多,很容易錯過一些明顯的事情。

但是,事實證明,上述情況都不是——我的驗證損失確實比我的訓練損失低。

要了解您的驗證loss可能低於訓練loss的三個主要原因,請繼續閱讀!

 

爲什麼我的驗證loss低於訓練loss?

在本教程的第一部分中,我們將討論神經網絡中“loss”的概念,包括loss代表什麼以及我們爲什麼對其進行測量。

在此,我們將實現一個基礎的CNN和訓練腳本,然後使用新近實現的CNN進行一些實驗(這將使我們的驗證損失低於我們的訓練損失)。

根據我們的結果,我將解釋您的驗證loss可能低於訓練loss的三個主要原因。

 

訓練神經網絡時的“loss”是什麼?

[1] 機器/深度學習的中的“loss”是什麼? 爲什麼我的驗證loss低於訓練loss?

在最基本的層次上,loss函數可量化給定預測變量對數據集中輸入數據點進行分類的“好”或“壞”程度。

loss越小,分類器在建模輸入數據和輸出目標之間的關係方面的工作就越好。

就是說,在某種程度上我們可以過度擬合我們的模型-通過過於緊密地建模訓練數據(modeling the training data too closely),我們的模型將失去泛化的能力。

因此,我們尋求:

  1. 儘可能降低loss,從而提高模型精度。
  2. 儘可能快地這樣子做,並減少超參數的更新/實驗次數。
  3. 所有這些都沒有過度擬合我們的網絡,也沒有將訓練數據建模得過於緊密。。

這是一種平衡,我們選擇loss函數和模型優化器會極大地影響最終模型的質量,準確性和通用性。

典型的損失函數(也稱爲“目標函數”或“評分函數”)包括:

  • Binary cross-entropy
  • Categorical cross-entropy
  • Sparse categorical cross-entropy
  • Mean Squared Error (MSE)
  • Mean Absolute Error (MAE)
  • Standard Hinge
  • Squared Hinge

對loss函數的全面回顧不在本文的範圍內,但就目前而言,只需瞭解對於大多數任務:

  • loss衡量你的模型的“好(goodness)”
  • loss越小越好
  • 但你要小心別過擬合

要了解在訓練自己的自定義神經網絡時loss函數的作用,請確保:

 

文件結構

從那裏,通過tree命令檢查項目/目錄結構:

  $ tree --dirsfirst
  .
  ├── pyimagesearch
  │   ├── __init__.py
  │   └── minivggnet.py
  ├── fashion_mnist.py
  ├── plot_shift.py
  └── training.pickle
  ​
  1 directory, 5 files

今天我們將使用一個稱爲MiniVGGNet的更小版本的vggnet。pyimagesearch模塊包括這個CNN。

我們的fashion_mnist.py腳本在fashion MNIST數據集上訓練MiniVGGNet。我在之前的一篇博文中寫過關於在時尚mnist上訓練MiniVGGNet,所以今天我們不會詳細討論。

https://www.pyimagesearch.com/2019/02/11/fashion-mnist-with-keras-and-deep-learning/

今天的訓練腳本將生成一個training.pickle文件,其中包含訓練精度/loss歷史記錄。在下面的原因部分中,我們將使用plot_shift.py將訓練loss圖移動半個epoch,以證明當驗證loss低於訓練loss時,測量loss的時間起作用。現在讓我們深入探討三個原因來回答這個問題:“爲什麼我的驗證loss比訓練loss低?“。

 

原因1:在訓練中應用正則化,但在驗證/測試中未應用正則化

 

[2] Aurélien在他的Twitter提要上回答了一個問題:“大家都想知道爲什麼驗證loss>訓練loss嗎?”。 第一個原因是在訓練過程中應用了正則化,但在驗證/測試過程中未進行正則化。

在訓練深度神經網絡時,我們經常應用正則化來幫助我們的模型:

  1. 獲得更高的驗證/測試精度
  2. 理想情況下,爲了更好地泛化驗證和測試集之外的數據

正則化方法通常會犧牲訓練準確性來提高驗證/測試準確性——在某些情況下,可能導致您的驗證loss低於訓練loss。

其次,請記住,在驗證/測試時不應用諸如dropout之類的正則化方法。

作爲的Aurelien顯示在圖2中,原因驗證loss應正則化(例如,在驗證/測試時應用dropout)可以讓你的訓練/驗證loss曲線看起來更相似。

 

原因2:訓練loss是在每個epoch測量的,而驗證loss是在每個epoch後測量的

[3] 驗證loss的原因2有時小於訓練損失,這與進行測量的時間有關

您可能會看到驗證loss低於訓練loss的第二個原因是由於如何測量和報告loss值:

  1. 訓練loss在每個epoch過程中測量的
  2. 而驗證loss是在每個epoch後測量的

在整個epoch內,您的訓練loss將不斷得到報告;但是,僅在當前訓練epoch完成後,才根據驗證集計算驗證指標。

這意味着,平均而言,訓練loss要提前半個epoch來衡量。

如果您將訓練loss向左移動半個epoch,您會發現訓練和驗證loss值之間的差距要小得多。

有關此行爲的示例,請閱讀以下部分。

 

執行我們的訓練腳本

我們將實現一個簡單的Python腳本,以在Fashion MNIST數據集上訓練類似於VGG的小型網絡(稱爲MiniVGGNet)。在訓練期間,我們會將訓練和驗證loss保存到磁盤中。然後,我們將創建一個單獨的Python腳本,以比較未變動和變動後的loss圖。

讓我們開始執行loss腳本:

  # import the necessary packages
  from pyimagesearch.minivggnet import MiniVGGNet
  from sklearn.metrics import classification_report
  from tensorflow.keras.optimizers import SGD
  from tensorflow.keras.datasets import fashion_mnist
  from tensorflow.keras.utils import to_categorical
  import argparse
  import pickle
  ​
  # construct the argument parser and parse the arguments
  ap = argparse.ArgumentParser()
  ap.add_argument("-i", "--history", required=True,
      help="path to output training history file")
  args = vars(ap.parse_args())

第2-8行導入了我們所需的包,模塊,類和函數。 即,我們導入MiniVGGNet(我們的CNN),fashion_mnist(我們的數據集)和pickle(確保可以序列化我們的訓練歷史以使用單獨的腳本來處理繪圖)。

命令行參數--history指向單獨的.pickle文件,該文件將很快包含我們的訓練歷史記錄(第11-14行)。

然後,我們初始化一些超參數,即我們要訓練的epoch數,初始學習率和批量大小:

  # initialize the number of epochs to train for, base learning rate,
  # and batch size
  NUM_EPOCHS = 25
  INIT_LR = 1e-2
  BS = 32

然後,我們繼續加載和預處理我們的Fashion MNIST數據:

  # grab the Fashion MNIST dataset (if this is your first time running
  # this the dataset will be automatically downloaded)
  print("[INFO] loading Fashion MNIST...")
  ((trainX, trainY), (testX, testY)) = fashion_mnist.load_data()
   
  # we are using "channels last" ordering, so the design matrix shape
  # should be: num_samples x rows x columns x depth
  trainX = trainX.reshape((trainX.shape[0], 28, 28, 1))
  testX = testX.reshape((testX.shape[0], 28, 28, 1))
   
  # scale data to the range of [0, 1]
  trainX = trainX.astype("float32") / 255.0
  testX = testX.astype("float32") / 255.0
   
  # one-hot encode the training and testing labels
  trainY = to_categorical(trainY, 10)
  testY = to_categorical(testY, 10)
   
  # initialize the label names
  labelNames = ["top", "trouser", "pullover", "dress", "coat",
      "sandal", "shirt", "sneaker", "bag", "ankle boot"]

第3-13行加載並預處理訓練/驗證數據。

第16和17行將我們的類別標籤二值化,而第20和21行則列出了人類可讀的類別標籤名稱,以供日後分類報告之用。

從這裏,我們擁有編譯和訓練Fashion MNIST數據上的MiniVGGNet模型所需的一切:

  # initialize the optimizer and model
  print("[INFO] compiling model...")
  opt = SGD(lr=INIT_LR, momentum=0.9, decay=INIT_LR / NUM_EPOCHS)
  model = MiniVGGNet.build(width=28, height=28, depth=1, classes=10)
  model.compile(loss="categorical_crossentropy", optimizer=opt,
      metrics=["accuracy"])
  ​
  # train the network
  print("[INFO] training model...")
  H = model.fit(trainX, trainY,
      validation_data=(testX, testY),
       batch_size=BS, epochs=NUM_EPOCHS)

第3-6行初始化並編譯MiniVGGNet模型。

然後,第10-12行擬合/訓練模型。

從這裏我們將評估我們的模型並序列化我們的訓練歷史:

  # make predictions on the test set and show a nicely formatted
  # classification report
  preds = model.predict(testX)
  print("[INFO] evaluating network...")
  print(classification_report(testY.argmax(axis=1), preds.argmax(axis=1),
      target_names=labelNames))
  ​
  # serialize the training history to disk
  print("[INFO] serializing training history...")
  f = open(args["history"], "wb")
  f.write(pickle.dumps(H.history))
  f.close()

第3-6行對測試集進行預測,並將分類報告打印到終端。

10-12行將我們的訓練準確性/損失歷史序列化爲.pickle文件。 我們將在單獨的Python腳本中使用訓練歷史記錄來繪製損耗曲線,包括一個顯示二分之一epoch偏移的圖。

從那裏打開一個終端,然後執行以下命令:

  $ python fashion_mnist.py --history training.pickle
  [INFO] loading Fashion MNIST...
  [INFO] compiling model...
  [INFO] training model...
  Train on 60000 samples, validate on 10000 samples   
  Epoch 1/25
  60000/60000 [==============================] - 200s 3ms/sample - loss: 0.5433 - accuracy: 0.8181 - val_loss: 0.3281 - val_accuracy: 0.8815
  Epoch 2/25
  60000/60000 [==============================] - 194s 3ms/sample - loss: 0.3396 - accuracy: 0.8780 - val_loss: 0.2726 - val_accuracy: 0.9006
  Epoch 3/25
  60000/60000 [==============================] - 193s 3ms/sample - loss: 0.2941 - accuracy: 0.8943 - val_loss: 0.2722 - val_accuracy: 0.8970
  Epoch 4/25
  60000/60000 [==============================] - 193s 3ms/sample - loss: 0.2717 - accuracy: 0.9017 - val_loss: 0.2334 - val_accuracy: 0.9144
  Epoch 5/25
  60000/60000 [==============================] - 194s 3ms/sample - loss: 0.2534 - accuracy: 0.9086 - val_loss: 0.2245 - val_accuracy: 0.9194
  ...
  Epoch 21/25
  60000/60000 [==============================] - 195s 3ms/sample - loss: 0.1797 - accuracy: 0.9340 - val_loss: 0.1879 - val_accuracy: 0.9324
  Epoch 22/25
  60000/60000 [==============================] - 194s 3ms/sample - loss: 0.1814 - accuracy: 0.9342 - val_loss: 0.1901 - val_accuracy: 0.9313
  Epoch 23/25
  60000/60000 [==============================] - 193s 3ms/sample - loss: 0.1766 - accuracy: 0.9351 - val_loss: 0.1866 - val_accuracy: 0.9320
  Epoch 24/25
  60000/60000 [==============================] - 193s 3ms/sample - loss: 0.1770 - accuracy: 0.9347 - val_loss: 0.1845 - val_accuracy: 0.9337
  Epoch 25/25
  60000/60000 [==============================] - 194s 3ms/sample - loss: 0.1734 - accuracy: 0.9372 - val_loss: 0.1871 - val_accuracy: 0.9312
  [INFO] evaluating network...
                precision    recall  f1-score   support
  ​
           top       0.87      0.91      0.89      1000
       trouser       1.00      0.99      0.99      1000
      pullover       0.91      0.91      0.91      1000
         dress       0.93      0.93      0.93      1000
          coat       0.87      0.93      0.90      1000
        sandal       0.98      0.98      0.98      1000
         shirt       0.83      0.74      0.78      1000
       sneaker       0.95      0.98      0.97      1000
           bag       0.99      0.99      0.99      1000
    ankle boot       0.99      0.95      0.97      1000
  ​
      accuracy                           0.93     10000
     macro avg       0.93      0.93      0.93     10000
  weighted avg       0.93      0.93      0.93     10000
  ​
  [INFO] serializing training history...

檢查工作目錄的內容,您應該有一個名爲training.pickle的文件-該文件包含我們的訓練歷史日誌。

  $ ls *.pickle
  training.pickle

在下一節中,我們將學習如何繪製這些值並將訓練信息向左移動半個epoch,從而使我們的訓練/驗證loss曲線看起來更加相似。

 

平移我們的訓練loss值

我們的plot_shift.py腳本用於繪製來自fashion_mnist.py的訓練歷史記錄。 使用此腳本,我們可以研究將訓練損失向左移動半個世紀如何使我們的訓練/驗證圖看起來更相似。

打開plot_shift.py文件並插入以下代碼:

  # import the necessary packages
  import matplotlib.pyplot as plt
  import numpy as np
  import argparse
  import pickle
   
  # construct the argument parser and parse the arguments
  ap = argparse.ArgumentParser()
  ap.add_argument("-i", "--input", required=True,
      help="path to input training history file")
  args = vars(ap.parse_args())

第2-5行導入matplotlib(用於繪製),NumPy(用於簡單的數組創建操作),argparse(命令行參數)和pickle(加載我們的序列化訓練歷史記錄)。

第8-11行解析--input命令行參數,該參數指向磁盤上的.pickle訓練歷史記錄文件。

讓我們繼續加載數據並初始化繪圖:

  # load the training history
  H = pickle.loads(open(args["input"], "rb").read())
   
  # determine the total number of epochs used for training, then
  # initialize the figure
  epochs = np.arange(0, len(H["loss"]))
  plt.style.use("ggplot")
  (fig, axs) = plt.subplots(2, 1)

第2行使用--input命令行參數加載序列化的訓練歷史記錄.pickle文件。

第6行爲我們的x軸騰出了空間,該空間從零到訓練歷史中的epoch數。

第7行和第8行將我們的繪圖圖設置爲同一圖像中的兩個堆疊繪圖:

  • top plot將按原樣包含loss曲線。
  • 另一方面,bottom plot將包括訓練loss(但不包括驗證loss)的偏移。 訓練loss將按照Aurélien的推文向左移動半個epoch。 然後,我們將能夠觀察繪圖線的排列是否更加緊密。

讓我們生成top plot

  # plot the *unshifted* training and validation loss
  plt.style.use("ggplot")
  axs[0].plot(epochs, H["loss"], label="train_loss")
  axs[0].plot(epochs, H["val_loss"], label="val_loss")
  axs[0].set_title("Unshifted Loss Plot")
  axs[0].set_xlabel("Epoch #")
  axs[0].set_ylabel("Loss")
  axs[0].legend()

然後繪製bottom plot

  # plot the *shifted* training and validation loss
  axs[1].plot(epochs - 0.5, H["loss"], label="train_loss")
  axs[1].plot(epochs, H["val_loss"], label="val_loss")
  axs[1].set_title("Shifted Loss Plot")
  axs[1].set_xlabel("Epoch #")
  axs[1].set_ylabel("Loss")
  axs[1].legend()
   
  # show the plots
  plt.tight_layout()
  plt.show()

請注意,在第2行上,訓練損失向左移動了0.5個epoch,即本例的核心。

現在,讓我們分析我們的訓練/驗證圖。

打開一個終端並執行以下命令:

  python plot_shift.py --input training.pickle

[4] 將訓練損失圖向左移動1/2個epoch,可以得到更多類似的圖。 顯然,測量時間回答了一個問題:“爲什麼我的驗證loss低於訓練loss?”。

如您所見,將訓練loss值向左(底部)移動一個半個epoch,使訓練/驗證曲線與未移動(頂部)圖更加相似。

 

原因#3:驗證集可能比訓練集更容易(否則可能會泄漏(leaks))

[5] 考慮如何獲取/生成驗證集。 常見的錯誤可能導致驗證loss少於訓練loss。

驗證loss低於訓練loss的最終最常見原因是由於數據本身分佈的問題。

考慮如何獲取驗證集:

  • 您可以保證驗證集是從與訓練集相同的分佈中採樣的嗎?
  • 您確定驗證示例與您的訓練圖像一樣具有挑戰性嗎?
  • 您是否可以確保沒有“數據泄漏”(即訓練樣本與驗證/測試樣本意外混入)?
  • 您是否確信自己的代碼正確創建了訓練集,驗證集和測試集?

每位深度學習從業者在其職業中都至少犯過一次以上錯誤。

是的,它確實會令人尷尬-但這很重要-確實會發生,所以現在就花點時間研究您的代碼。

BONUS: Are you training hard enough?

[6] 如果您想知道爲什麼驗證損失低於訓練loss,也許您沒有“足夠努力地訓練”。

Aurélien在推文中沒有提及的一個方面是“足夠努力地訓練(training hard enough)”的概念。

在訓練深度神經網絡時,我們最大的擔心幾乎總是過擬合——爲了避免過擬合,我們引入了正則化技術(在上面的原因1中進行了討論)。我們用以下形式應用正則化:

  • Dropout
  • L2權重衰減
  • 減少模型容量(即更淺的模型)

我們的學習率也趨於保守一些,以確保我們的模型不會在虧損形勢下超越虧損較低的領域。

一切都很好,但是有時候我們最終會過度規範我們的模型 (over-regularizing our models)

如果您經歷了驗證loss低於上述詳細說明的訓練loss的所有三個原因,則可能是您的模型over-regularized了。通過以下方法開始放寬正則化約束:

  • 降低L2權重衰減強度。
  • 減少申請的dropout數量。
  • 增加模型容量(即,使其更深)。

您還應該嘗試以更高的學習率進行訓練,因爲您可能對此過於保守。

 

總結

今天的教程深受作者AurélienGeron的以下推文啓發。

在線程中,Aurélien簡潔明瞭地解釋了訓練深度神經網絡時驗證損失可能低於訓練損失的三個原因:

  1. 原因1:在訓練期間應用正則化,但在驗證/測試期間未進行正則化。如果在驗證/測試期間添加正則化損失,則損失值和曲線將看起來更加相似。
  2. 原因2:訓練損失是在每個epoch期間測量的,而驗證損失是在每個epoch後測量的。平均而言,訓練損失的測量時間是前一個時期的1/2。如果將訓練損失曲線向左移動半個epoch,則損失會更好。
  3. 原因3:您的驗證集可能比訓練集更容易,或者代碼中的數據/錯誤泄漏。確保您的驗證集大小合理,並且是從與您的訓練集相同的分佈(和難度)中抽取的。
  4. 獎勵:您的模型可能over-regularizing 。嘗試減少正則化約束,包括增加模型容量(即通過更多參數使其更深),減少dropout,降低L2權重衰減強度等。

希望這有助於消除對爲什麼您的驗證損失可能低於培訓損失的困惑!

英文原文鏈接:https://www.pyimagesearch.com/2019/10/14/why-is-my-validation-loss-lower-than-my-training-loss/

發佈了52 篇原創文章 · 獲贊 110 · 訪問量 25萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章