pytorch 的 torch.distributions 中可以定義正態分佈:
import torch
from torch.distributions import Normal
mean=torch.Tensor([0,2])
normal=Normal(mean,1)
sample()
就是直接在定義的正太分佈(均值爲mean,標準差std是1)上採樣:
result = normal.sample()
print("sample():",result)
輸出:
sample(): tensor([-1.3362, 3.1730])
rsample()
不是在定義的正太分佈上採樣,而是先對標準正太分佈 N(0,1) 進行採樣,然後輸出: mean + std × 採樣值
result = normal.rsample()
print("rsample():",result)
輸出:
rsample: tensor([ 0.0530, 2.8396])
log_prob(value)
是計算value在定義的正態分佈(mean,1)中對應的概率的對數,正太分佈概率密度函數是:
對其取對數可得:
這裏我們通過對數概率還原其對應的真實概率:
print("result log_prob:",normal.log_prob(result).exp())
輸出:
result log_prob: tensor([ 0.1634, 0.2005])