js-divergence的pytorch實現

參考文檔

JSD實現代碼

若有紕漏,敬請指出,感謝!

def js_div(p_output, q_output, get_softmax=True):
    """
    Function that measures JS divergence between target and output logits:
    """
    KLDivLoss = nn.KLDivLoss(reduction='batchmean')
    if get_softmax:
        p_output = F.softmax(p_output)
        q_output = F.softmax(q_output)
    log_mean_output = ((p_output + q_output )/2).log()
    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2

一些注意事項

  1. 關於dlv函數的使用:

    函數中的 p q 位置相反(也就是想要計算D(p||q),要寫成kl_div(q.log(),p)的形式),而且q要先取 log

  2. JS 散度度量了兩個概率分佈的相似度,基於KL散度的變體,解決了KL散度非對稱的問題。所以jsd(q, p)與jsd(p, q)一致。
    在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章