1.OneHotCategorical
torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)
根據給定的概率probs, 創建一個 one-hot 的類別分佈.
m = OneHotCategorical(torch.tensor([ 0.1, 0.0, 0.9, 0.0 ]))
m.sample() # equal probability of 0, 1, 2, 3
#tensor([0., 0., 1., 0.])
參考:
pytorch distributions
————————————————
版權聲明:本文爲CSDN博主「rosefunR」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/rosefun96/article/details/103804510