tf.multinomial()/tf.random.categorical()用法解析

tf.multinomial()/tf.random.categorical()用法解析

首先說一下,tf.multinomial()在tensorflow2.0版本已經被移除,取而代之的就是tf.random.categorical()

網上的很多博客解釋的都不清楚,官網......解釋的也很模糊,於是想自我總結一下,順便幫助對此也很困惑的人~

因爲tf.multinomial()被tf.random.categorical()替代,所以下文以tf.random.categorical()爲描述方式進行介紹

官網的解釋

tf.random.categorical

從一個分類分佈中抽取樣本(tf.multinomial()是多項分佈)

別名:

  • tf.compat.v1.random.categorical
  • tf.compat.v2.random.categorical
tf.random.categorical(
    logits,
    num_samples,
    dtype=None,
    seed=None,
    name=None
)

例子:

# samples has shape [1, 5], where each value is either 0 or 1 with equal
# probability.
samples = tf.random.categorical(tf.math.log([[10., 10.]]), 5)

參數:

  • logits: 形狀爲 [batch_size, num_classes]的張量. 每個切片 [i, :] 代表對於所有類的未正規化的log概率。
  • num_samples: 0維,從每一行切片中抽取的獨立樣本的數量。
  • dtype: 用於輸出的整數類型,默認爲int64。
  • seed: 一個Python整數,用於創建分佈的隨機種子。See tf.compat.v1.set_random_seedfor behavior.
  • name: 操作的可選名字

Returns:

形狀爲[batch_size, num_samples]的抽取樣本.

個人理解

1. 這個函數的意思就是,你給了一個batch_size × num_classes的矩陣,這個矩陣是這樣的:
每一行相當於log(p(x)),這裏假設p(x)=[0.4,0.3,0.2,0.1],(p(x)的特性就是和爲1),
然後再取log,那麼log(p(x))就等於[-0.9162907 -1.20397282 -1.60943794 -2.30258512]
函數利用你給的分佈概率,從其中的每一行中抽取num_samples次,最終形成的矩陣就是batch_szie × num_samples了。

2. 這裏的抽樣方法可以再詳細解釋一下,舉個例子(請不要考慮真實性),給一行[1.0,2.0,2.0,2.0,6.0],採樣4次,那麼結果很大可能都是[4,4,4,4](不信可以試一下),因爲下標爲4的概率(6.0)遠遠高於其他的概率,當然也會出現比如[4,4,2,4]這樣的情況,就是說其他的下標因爲給定的概率就低,所以被採樣到的概率也就低了。

3. 官網解釋中logits,也就是你給的矩陣,每個切片 [i, :] 代表對於所有類的未正規化的log概率(即其和不爲1),但必須是小數,就像官網的樣例一樣,就算是整數,後面也要加一個小數點,否則會報錯。

4. 返回值是什麼的問題,返回的其實不是抽取到的樣本,而是抽取樣本在每一行的下標。

爲了能更加充分的理解,下面奉上一個小小的例子:

import tensorflow as tf;
for i in tf.range(10):
    samples = tf.random.categorical([[1.0,1.0,1.0,1.0,4.0],[1.0,1,1,1,1]], 6)
    tf.print(samples)

 輸出結果

[[4 4 4 4 4 1]
 [3 1 3 0 4 3]]
[[4 0 4 4 4 1]
 [1 0 2 4 1 2]]
[[0 4 4 0 4 4]
 [3 0 0 1 1 4]]
[[4 4 4 4 4 0]
 [2 1 4 3 4 4]]
[[4 4 2 4 4 4]
 [1 3 1 0 4 0]]
[[4 4 4 4 4 4]
 [3 0 4 1 1 1]]
[[4 4 0 0 4 4]
 [3 3 0 3 2 2]]
[[1 4 4 4 4 4]
 [2 2 1 3 0 2]]
[[4 4 4 4 4 4]
 [2 4 4 3 2 2]]
[[4 4 4 4 3 4]
 [2 4 2 2 1 0]]

看到這估計你就能理解了,其中[[1.0,1.0,1.0,1.0,4.0],[1.0,1,1,1,1]]就是需要進行採樣的矩陣,這裏加小數點其實可以只加一個,只要讓程序知道你用的是概率就行(當然實際都是通過tf.log()得到的不用手動輸入),輸出結果自然就是樣本所在行的下標,多運行幾次,就能更直觀的感受到,設定的概率和採樣結果之間的關係。(比如這裏第一行的採樣結果很多都是最後一個樣本,第二行因爲概率相同,採樣結果就很均勻)

就這麼多啦,如果文章有錯誤或者有疑問歡迎評論區交流呀(●’◡’●)ノ ~

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章