PyTorch梯度裁剪避免訓練loss nan

近來在訓練檢測網絡的時候會出現loss爲nan的情況,需要中斷重新訓練,會很麻煩。因而選擇使用PyTorch提供的梯度裁剪庫來對模型訓練過程中的梯度範圍進行限制,修改之後,不再出現loss爲nan的情況。

PyTorch中採用torch.nn.utils.clip_grad_norm_來實現梯度裁剪,鏈接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

訓練代碼使用示例如下:

from torch.nn.utils import clip_grad_norm_

outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()

# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)

optimizer.step()

其中,max_norm爲梯度的最大範數,也是梯度裁剪時主要設置的參數。


備註:網上有同學提醒在(強化學習)使用了梯度裁剪之後訓練時間會大大增加。目前在我的檢測網絡訓練中暫時還沒有碰到這個問題,以後遇到再來更新。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章