Pytorch: torch.distributions庫

pytorch 的 torch.distributions 中可以定義正態分佈:

import torch
from torch.distributions import  Normal
mean=torch.Tensor([0,2])
normal=Normal(mean,1)

sample()就是直接在定義的正太分佈(均值爲mean,標準差std是1)上採樣:

result = normal.sample()
print("sample():",result)

輸出:

sample(): tensor([-1.3362,  3.1730])

rsample()不是在定義的正太分佈上採樣,而是先對標準正太分佈 N(0,1) 進行採樣,然後輸出: mean + std × 採樣值

result = normal.rsample()
print("rsample():",result)

輸出:

rsample: tensor([ 0.0530,  2.8396])

log_prob(value) 是計算value在定義的正態分佈(mean,1)中對應的概率的對數,正太分佈概率密度函數是:
在這裏插入圖片描述
對其取對數可得:
在這裏插入圖片描述
這裏我們通過對數概率還原其對應的真實概率:

print("result log_prob:",normal.log_prob(result).exp())

輸出:

result log_prob: tensor([ 0.1634,  0.2005])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章