【pytorch】踩坑PyTorch中的dropout

作者:雷傑
鏈接:https://www.zhihu.com/question/67209417/answer/302434279
來源:知乎
著作權歸作者所有。商業轉載請聯繫作者獲得授權,非商業轉載請註明出處。


剛踩的坑, 差點就哭出來了TT. — 我明明加了一百個dropout, 爲什麼結果一點都沒變
使用F.dropout ( nn.functional.dropout )的時候需要設置它的training這個狀態參數與模型整體的一致.
比如:

Class DropoutFC(nn.Module):
    def __init__(self):
        super(DropoutFC, self).__init__()
        self.fc = nn.Linear(100,20)

    def forward(self, input):
        out = self.fc(input)
        out = F.dropout(out, p=0.5)
        return out

Net = DropoutFC()
Net.train()
# train the Net

這段代碼中的F.dropout實際上是沒有任何用的, 因爲它的training狀態一直是默認值False. 由於F.dropout只是相當於引用的一個外部函數, 模型整體的training狀態變化也不會引起F.dropout這個函數的training狀態發生變化. 所以, 此處的out = F.dropout(out) 就是 out = out.
正確的使用方法如下, 將模型整體的training狀態參數傳入dropout函數

Class DropoutFC(nn.Module):
   def __init__(self):
       super(DropoutFC, self).__init__()
       self.fc = nn.Linear(100,20)

   def forward(self, input):
       out = self.fc(input)
       out = F.dropout(out, p=0.5, training=self.training)
       return out

Net = DropoutFC()
Net.train()
# train the Net

或者直接使用nn.Dropout() (nn.Dropout()實際上是對F.dropout的一個包裝, 也將self.training傳入了)

Class DropoutFC(nn.Module):
  def __init__(self):
      super(DropoutFC, self).__init__()
      self.fc = nn.Linear(100,20)
      self.dropout = nn.Dropout(p=0.5)

  def forward(self, input):
      out = self.fc(input)
      out = self.dropout(out)
      return out
Net = DropoutFC()
Net.train()
# train the Net
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章