Implementing the KL divergence between two multivariate Gaussian distributions in tensorflow2.x (very important)

0. Background

  • Suppose you want to compute the KL divergence between two multivariate Gaussian distributions in tensorflow, using the closed-form solution. How do you implement it in tensorflow2.x? The closed form is:
    KL(N1 || N2) = 0.5 * [ tr(Σ2^{-1} Σ1) + (μ2 - μ1)^T Σ2^{-1} (μ2 - μ1) - k + ln(det Σ2 / det Σ1) ], where k is the feature dimension.
  • Looking at this formula, I suspect most people's heads start to hurt, especially at training time, when the Batch dimension also has to be handled. Today we implement it with tensorflow.

1. tensorflow matrix operations

1.1 Multiplying multi-dimensional matrices

  • Usually we only think about multiplying 2-D matrices, where all we need to watch is the two matrices' dimensions. Sometimes, though, we also have to handle an extra dimension such as Batch_size; that is the problem this subsection solves.

1.1.1 The tf.matmul function


  • a_is_sparse and b_is_sparse are only valid when rank = 2; otherwise an error is raised
  • Testing with 2-D and 3-D tensors below, we check whether this op only multiplies over the last two dimensions, leaving the other (batch) dimensions unchanged and requiring them to be equal in size
  • Note: the two matrices must have the same dtype, otherwise an error is raised
# 2-D tensor `a`
a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])    => [[1 2 3]
                                                          [4 5 6]]
# 2-D tensor `b`
b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2]) => [[ 7  8]
                                                          [ 9 10]
                                                          [11 12]]
c = tf.matmul(a, b)  # shape=(2, 2)                   => [[ 58  64]
                                                          [139 154]]

# 3-D tensor `a`
a = tf.constant(np.arange(1, 13, dtype=np.int32),
                shape=[2, 2, 3])                      => [[[ 1  2  3]
                                                           [ 4  5  6]],
                                                          [[ 7  8  9]
                                                           [10 11 12]]]
# 3-D tensor `b`
b = tf.constant(np.arange(13, 25, dtype=np.int32),
                shape=[2, 3, 2])                      => [[[13 14]
                                                           [15 16]
                                                           [17 18]],
                                                          [[19 20]
                                                           [21 22]
                                                           [23 24]]]
c = tf.matmul(a, b)  # shape=(2, 2, 2)                => [[[ 94 100]
                                                           [229 244]],
                                                          [[508 532]
                                                           [697 730]]]
  • Conclusion: what we compute here is the true matrix product, not the element-wise product (which uses "*"). tf.matmul only operates on the last two dimensions; all remaining dimensions stay unchanged and must have equal sizes in both tensors. If we want to dig into the source code, OK:

  • The answer: when the matrices have rank > 2, batch_mat_mul is called automatically

    if (not a_is_sparse and
        not b_is_sparse) and ((a_shape is None or len(a_shape) > 2) and
                              (b_shape is None or len(b_shape) > 2)):
      # BatchMatmul does not support transpose, so we conjugate the matrix and
      # use adjoint instead. Conj() is a noop for real matrices.
      if transpose_a:
        a = conj(a)
        adjoint_a = True
      if transpose_b:
        b = conj(b)
        adjoint_b = True
      return gen_math_ops.batch_mat_mul(
          a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
 
 
# The definition of batch_mat_mul called in the code above
def batch_mat_mul(x, y, adj_x=False, adj_y=False, name=None):
  r"""Multiplies slices of two tensors in batches.
  Multiplies all slices of `Tensor` `x` and `y` (each slice can be
  viewed as an element of a batch), and arranges the individual results
  in a single output tensor of the same batch size. Each of the
  individual slices can optionally be adjointed (to adjoint a matrix
  means to transpose and conjugate it) before multiplication by setting
  the `adj_x` or `adj_y` flag to `True`, which are by default `False`.
  
  The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
  and `[..., r_y, c_y]`.
  
  The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
  
      r_o = c_x if adj_x else r_x
      c_o = r_y if adj_y else c_y
  
  It is computed as:
  
      output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
  Args:
    x: A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
      2-D or higher with shape `[..., r_x, c_x]`.
    y: A `Tensor`. Must have the same type as `x`.
      2-D or higher with shape `[..., r_y, c_y]`.
    adj_x: An optional `bool`. Defaults to `False`.
      If `True`, adjoint the slices of `x`. Defaults to `False`.
    adj_y: An optional `bool`. Defaults to `False`.
      If `True`, adjoint the slices of `y`. Defaults to `False`.
    name: A name for the operation (optional).
  Returns:
    A `Tensor`. Has the same type as `x`.
  """

1.1.2 Using the overloaded @ operator

  • @ is an overloaded operator; the following note from the TensorFlow source explains it:

  # Since python >= 3.5 the @ operator is supported (see PEP 465).
  # In TensorFlow, it simply calls the `tf.matmul()` function, so the
  # following lines are equivalent:
  d = a @ b @ [[10.], [11.]]
  d = tf.matmul(tf.matmul(a, b), [[10.], [11.]])
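
A quick check that the operator behaves the same for batched tensors (a minimal sketch; the shapes below are just illustrative):

import tensorflow as tf

# A hypothetical batch of 4 matrix pairs.
a = tf.random.normal(shape=(4, 2, 3))
b = tf.random.normal(shape=(4, 3, 5))

c1 = a @ b               # the overloaded operator
c2 = tf.matmul(a, b)     # the explicit call

print(c1.shape)                                # (4, 2, 5)
print(tf.reduce_max(tf.abs(c1 - c2)).numpy())  # 0.0, the two are identical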

1.2 Transposing multi-dimensional matrices

  • Imagine we now have a mean vector μ with shape=(Batch, μ_dim). For every μ in the batch we want its transpose, i.e. shape=(Batch, 1, μ_dim) or (Batch, μ_dim, 1). This takes two steps:

1.2.1 Expanding dimensions with tf.expand_dims

tf.expand_dims(a, axis=...) inserts a new dimension of size 1 into the given tensor a at the specified axis. For example:

a = tf.ones(shape=(10,5))
a = tf.expand_dims(a, axis=-1)

print(a.shape)
>>> (10, 5, 1)

So we can use this function to turn
(B, μ_dim) ------> (B, μ_dim, 1), which makes the later matrix operations convenient.

1.2.2 tf.squeeze(a,axis=)

  • This function is the opposite of tf.expand_dims: it removes dimensions of size 1
a = tf.ones(shape=(1,1,2,2))
a = tf.squeeze(a)  # pass axis=... to remove only a specific size-1 dimension
a.shape

# result
(2, 2)
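
This comes in handy later: the quadratic term (μ2 - μ1)^T Σ2^{-1} (μ2 - μ1) falls out of the batched matrix multiplications with shape (B, 1, 1), and squeezing the two trailing axes brings it back to (B,). A minimal sketch (the shapes are illustrative):

import tensorflow as tf

term = tf.ones(shape=(10, 1, 1))         # e.g. a batch of 1x1 "scalars"
term = tf.squeeze(term, axis=[-2, -1])   # remove only the two trailing size-1 axes
print(term.shape)                        # (10,)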

1.2.3 Matrix transpose

Method 1: the tensorflow1.x-compatible function tf.compat.v1.matrix_transpose

  • It transposes the last two dimensions of a tensor, which is convenient and well suited to the rank > 2 case
  • tf.compat.v1 is the module that keeps tensorflow1.x APIs available when running tensorflow2.x ("compat" stands for compatible, "v1" for version 1)
matrix_transpose(
    a,
    name='matrix_transpose'
)

# Matrix with no batch dimension.
# 'x' is [[1 2 3]
#         [4 5 6]]
tf.compat.v1.matrix_transpose(x) ==> [[1 4]
                                      [2 5]
                                      [3 6]]

# Matrix with two batch dimensions.
# x.shape is [1, 2, 3, 4]
# tf.compat.v1.matrix_transpose(x) has shape [1, 2, 4, 3]

Method 2: the tensorflow2.x tf.transpose() function

For details, this blog post is recommended (a short sketch follows below):
https://blog.csdn.net/qq_40994943/article/details/85270159
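
For reference, a minimal sketch of the TF2-native route (the shapes are just illustrative): tf.transpose takes a perm argument listing the new axis order, and tf.linalg.matrix_transpose always swaps just the last two axes.

import tensorflow as tf

x = tf.ones(shape=(8, 1, 5))             # e.g. (Batch, 1, dim)

# perm lists the new axis order; here we swap the last two axes.
x_t = tf.transpose(x, perm=[0, 2, 1])    # (8, 5, 1)

# TF2 shortcut that always transposes only the last two axes.
x_t2 = tf.linalg.matrix_transpose(x)     # (8, 5, 1)

print(x_t.shape, x_t2.shape)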

1.3 Computing the determinant of a matrix

  • Use the tf.compat.v1.matrix_determinant() function
matrix_determinant(
    input,
    name=None
)

Only square matrices have a determinant, so the input must have shape [..., M, M]. This function computes the determinant for an entire batch of matrices at once.

a = tf.ones(shape=(5,4,4))
det_a = tf.compat.v1.matrix_determinant(a)
print(det_a)

# result
<tf.Tensor: shape=(5,), dtype=float32, numpy=array([0., 0., 0., 0., 0.], dtype=float32)>
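
The all-ones matrices above are singular, so every determinant comes out as 0. For the diagonal covariance matrices used later, the determinant is just the product of the diagonal entries; a minimal sketch (the values are made up) showing both routes agree:

import tensorflow as tf

sigma = 2.0 * tf.ones(shape=(5, 4))                   # a batch of diagonals, every entry 2
sigma_matrix = tf.compat.v1.matrix_diag(sigma)        # (5, 4, 4) diagonal matrices

det1 = tf.compat.v1.matrix_determinant(sigma_matrix)  # (5,), each value 2**4 = 16
det2 = tf.reduce_prod(sigma, axis=-1)                 # the same values without building matrices

print(det1.numpy(), det2.numpy())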

1.4 Computing the inverse of a matrix

1.4.1 tf.compat.v1.matrix_inverse()


  • In deep learning we usually assume the covariance between feature dimensions is 0, so the covariance matrix is diagonal; we only need the diagonal elements and can build the matrix (and its inverse, the diagonal matrix of reciprocals) directly from them.
  • Also note that before inverting a matrix you must make sure it is invertible (a small sketch follows).
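
A minimal sketch of inverting a whole batch at once (the batch of 3*I matrices is just an illustrative, trivially invertible example); note that in the KL code below the general inverse is avoided entirely and Σ2^{-1} is built directly from 1/σ2:

import tensorflow as tf

# A batch of invertible matrices: 3 * I for every element of the batch.
a = 3.0 * tf.eye(4, batch_shape=[5])                  # (5, 4, 4)

a_inv = tf.compat.v1.matrix_inverse(a)                # (5, 4, 4), each diagonal entry is 1/3
print(tf.reduce_max(tf.abs(a @ a_inv - tf.eye(4))).numpy())  # ~0.0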

1.4.2 tf.compat.v1.matrix_diag()
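
A minimal sketch of what tf.compat.v1.matrix_diag does (the sigma values here are made up): it turns a batch of diagonal vectors of shape (B, dim) into a batch of diagonal matrices of shape (B, dim, dim), which is exactly how the sigma vectors become covariance matrices in the KL code below.

import tensorflow as tf

sigma = tf.constant([[1., 2., 3.],
                     [4., 5., 6.]])                   # (B, dim) = (2, 3) diagonal vectors

sigma_matrix = tf.compat.v1.matrix_diag(sigma)        # (2, 3, 3) diagonal covariance matrices
print(sigma_matrix[0].numpy())
# [[1. 0. 0.]
#  [0. 2. 0.]
#  [0. 0. 3.]]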

2. Combined, they become stronger

Combining all the operations above, we can now compute the KL divergence between two Gaussian distributions:

import tensorflow as tf

def compute_kl(u1, sigma1, u2, sigma2, dim):
    """
    Compute the KL divergence KL(N1||N2) between two multivariate Gaussians.

    All inputs have shape (B1, B2, ..., dim); sigma1/sigma2 are the diagonals of the
    covariance matrices (off-diagonal covariances are assumed to be 0).
    Here we assume a single batch dimension, i.e. shape (B, dim).

    dim: the feature dimension
    """
    sigma1_matrix = tf.compat.v1.matrix_diag(sigma1)                    # (B,dim,dim)
    sigma1_matrix_det = tf.compat.v1.matrix_determinant(sigma1_matrix)  # (B,)

    sigma2_matrix = tf.compat.v1.matrix_diag(sigma2)                    # (B,dim,dim)
    sigma2_matrix_det = tf.compat.v1.matrix_determinant(sigma2_matrix)  # (B,)
    sigma2_matrix_inv = tf.compat.v1.matrix_diag(1./sigma2)             # (B,dim,dim)

    delta_u = tf.expand_dims((u2-u1), axis=-1)                          # (B,dim,1)
    delta_u_transpose = tf.compat.v1.matrix_transpose(delta_u)          # (B,1,dim)

    term1 = tf.reduce_sum((1./sigma2)*sigma1, axis=-1)                  # (B,), the trace term
    term2 = delta_u_transpose @ sigma2_matrix_inv @ delta_u             # (B,1,1)
    term2 = tf.squeeze(term2, axis=[-2, -1])                            # (B,), the quadratic term
    term3 = -dim
    term4 = tf.math.log(sigma2_matrix_det) - tf.math.log(sigma1_matrix_det)  # (B,)

    KL = 0.5 * (term1 + term2 + term3 + term4)                          # (B,)

    # If you want the mean over the batch:
    KL_mean = tf.reduce_mean(KL)

    return KL_mean

# Test
dim = 5
u1, sigma1 = tf.zeros(shape=(10,5)), tf.ones(shape=(10,5))    # N(0, I)
u2, sigma2 = tf.zeros(shape=(10,5)), tf.ones(shape=(10,5))    # N(0, I)
u3, sigma3 = tf.zeros(shape=(10,5)), 4*tf.ones(shape=(10,5))  # N(0, 4I)

KL1 = compute_kl(u1,sigma1,u2,sigma2,dim)
KL2 = compute_kl(u1,sigma1,u3,sigma3,dim)

print(KL1,"\n",KL2,sep='')

# result
# tf.Tensor(0.0, shape=(), dtype=float32)
# tf.Tensor(1.5907359, shape=(), dtype=float32)
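
As a rough sanity check of the second value: with μ1 = μ2 = 0, Σ1 = I, Σ2 = 4I and dim = 5, the quadratic term vanishes and the closed form gives 0.5 * (5/4 - 5 + 5*ln 4) ≈ 0.5 * 3.1815 ≈ 1.5907, which matches the printed result.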