focal loss用來解決樣本不均衡的分類問題。
假設正樣本(label=1)少,負樣本多,定義focal loss如下
Loss = -[alpha*(1-y_hat)^2yln(y_hat)
+ (1-alpha)y_hat^2(1-y)*ln(1-y_hat)]
其中y_hat:(batch, seq, tags),預測出的
y: (batch, seq, tags)
alpha:(1, 1, tags)
alpha是超參數,是正樣本損失佔總體的比例,初始化爲 少數樣本/總樣本 的比值,調整策略如下,正樣本的precision<recall時,訓練更關注正樣本,alpha調低,反之調高。
調整策略也可以爲:
正類的識別正確率與負類的識別正確率