近來在訓練檢測網絡的時候會出現loss爲nan的情況,需要中斷重新訓練,會很麻煩。因而選擇使用PyTorch提供的梯度裁剪庫來對模型訓練過程中的梯度範圍進行限制,修改之後,不再出現loss爲nan的情況。
PyTorch中採用torch.nn.utils.clip_grad_norm_來實現梯度裁剪,鏈接如下:
https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
訓練代碼使用示例如下:
from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()
其中,max_norm爲梯度的最大範數,也是梯度裁剪時主要設置的參數。
備註:網上有同學提醒在(強化學習)使用了梯度裁剪之後訓練時間會大大增加。目前在我的檢測網絡訓練中暫時還沒有碰到這個問題,以後遇到再來更新。