參考文檔
JSD實現代碼
若有紕漏,敬請指出,感謝!
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='batchmean')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output )/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
一些注意事項
-
關於dlv函數的使用:
函數中的 p q 位置相反(也就是想要計算D(p||q),要寫成kl_div(q.log(),p)的形式),而且q要先取 log
-
JS 散度度量了兩個概率分佈的相似度,基於KL散度的變體,解決了KL散度非對稱的問題。所以jsd(q, p)與jsd(p, q)一致。