tf.multinomial()/tf.random.categorical()用法解析
首先說一下,tf.multinomial()在tensorflow2.0版本已經被移除,取而代之的就是tf.random.categorical()
網上的很多博客解釋的都不清楚,官網......解釋的也很模糊,於是想自我總結一下,順便幫助對此也很困惑的人~
因爲tf.multinomial()被tf.random.categorical()替代,所以下文以tf.random.categorical()爲描述方式進行介紹
官網的解釋
tf.random.categorical
從一個分類分佈中抽取樣本(tf.multinomial()是多項分佈)
別名:
tf.compat.v1.random.categorical
tf.compat.v2.random.categorical
tf.random.categorical(
logits,
num_samples,
dtype=None,
seed=None,
name=None
)
例子:
# samples has shape [1, 5], where each value is either 0 or 1 with equal
# probability.
samples = tf.random.categorical(tf.math.log([[10., 10.]]), 5)
參數:
logits
: 形狀爲[batch_size, num_classes]的張量
. 每個切片[i, :]
代表對於所有類的未正規化的log概率。num_samples
: 0維,從每一行切片中抽取的獨立樣本的數量。dtype
: 用於輸出的整數類型,默認爲int64。seed
: 一個Python整數,用於創建分佈的隨機種子。Seetf.compat.v1.set_random_seed
for behavior.name
: 操作的可選名字
Returns:
形狀爲[batch_size, num_samples]的抽取樣本
.
個人理解
1. 這個函數的意思就是,你給了一個batch_size × num_classes的矩陣,這個矩陣是這樣的:
每一行相當於log(p(x)),這裏假設p(x)=[0.4,0.3,0.2,0.1],(p(x)的特性就是和爲1),
然後再取log,那麼log(p(x))就等於[-0.9162907 -1.20397282 -1.60943794 -2.30258512]
函數利用你給的分佈概率,從其中的每一行中抽取num_samples次,最終形成的矩陣就是batch_szie × num_samples了。
2. 這裏的抽樣方法可以再詳細解釋一下,舉個例子(請不要考慮真實性),給一行[1.0,2.0,2.0,2.0,6.0],採樣4次,那麼結果很大可能都是[4,4,4,4](不信可以試一下),因爲下標爲4的概率(6.0)遠遠高於其他的概率,當然也會出現比如[4,4,2,4]這樣的情況,就是說其他的下標因爲給定的概率就低,所以被採樣到的概率也就低了。
3. 官網解釋中logits,也就是你給的矩陣,每個切片 [i, :]
代表對於所有類的未正規化的log概率(即其和不爲1),但必須是小數,就像官網的樣例一樣,就算是整數,後面也要加一個小數點,否則會報錯。
4. 返回值是什麼的問題,返回的其實不是抽取到的樣本,而是抽取樣本在每一行的下標。
爲了能更加充分的理解,下面奉上一個小小的例子:
import tensorflow as tf;
for i in tf.range(10):
samples = tf.random.categorical([[1.0,1.0,1.0,1.0,4.0],[1.0,1,1,1,1]], 6)
tf.print(samples)
輸出結果
[[4 4 4 4 4 1]
[3 1 3 0 4 3]]
[[4 0 4 4 4 1]
[1 0 2 4 1 2]]
[[0 4 4 0 4 4]
[3 0 0 1 1 4]]
[[4 4 4 4 4 0]
[2 1 4 3 4 4]]
[[4 4 2 4 4 4]
[1 3 1 0 4 0]]
[[4 4 4 4 4 4]
[3 0 4 1 1 1]]
[[4 4 0 0 4 4]
[3 3 0 3 2 2]]
[[1 4 4 4 4 4]
[2 2 1 3 0 2]]
[[4 4 4 4 4 4]
[2 4 4 3 2 2]]
[[4 4 4 4 3 4]
[2 4 2 2 1 0]]
看到這估計你就能理解了,其中[[1.0,1.0,1.0,1.0,4.0],[1.0,1,1,1,1]]就是需要進行採樣的矩陣,這裏加小數點其實可以只加一個,只要讓程序知道你用的是概率就行(當然實際都是通過tf.log()得到的不用手動輸入),輸出結果自然就是樣本所在行的下標,多運行幾次,就能更直觀的感受到,設定的概率和採樣結果之間的關係。(比如這裏第一行的採樣結果很多都是最後一個樣本,第二行因爲概率相同,採樣結果就很均勻)
就這麼多啦,如果文章有錯誤或者有疑問歡迎評論區交流呀(●’◡’●)ノ ~