TensorFlow之SparseTensor對象

在TensorFlow中,SparseTensor對象表示稀疏矩陣。SparseTensor對象通過3個稠密矩陣indices, values及dense_shape來表示稀疏矩陣,這三個稠密矩陣的含義介紹如下:

1. indices:數據類型爲int64的二維Tensor對象,它的Shape爲[N, ndims]。indices保存的是非零值的索引,即稀疏矩陣中除了indices保存的位置之外,其他位置均爲0。

2. values:一維Tensor對象,其Shape爲[N]。它對應的是稀疏矩陣中indices索引位置中的值。

3. dense_shape:數據類型爲int64的一維Tensor對象,其維度爲[ndims],用於指定當前稀疏矩陣對應的Shape。

舉個例子,稀疏矩陣對象SparseTensor(indices=[[0, 0],[1, 1]], value=[1, 1], dense_shape=[3, 3])對應的矩陣如下:

 [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]]

函數tf.sparse_tensor_to_dense用於將稀疏矩陣SparseTensor對象轉換爲稠密矩陣,函數tf.sparse_tensor_to_dense的原型如下:

tf.sparse_tensor_to_dense(
     sp_input,
     default_value=0,
     validate_indices=True,
     name=None
)

各個參數的類型及其含義介紹如下:

sp_input: SparseTensor對象,用於作爲轉換稠密矩陣的輸入。

default_value: 標量類型。稀疏矩陣sp_input中的indices沒有指定位置的元素值,默認爲0。

validate_indices: bool類型,用於設置是否對索引值按照字典順序排序。

name: string類型。返回的Tensor對象的名稱的前綴。

 

示例代碼:

import tensorflow as tf

# 定義Tensor對象
indices_tf = tf.constant([[0, 0], [1, 1]], dtype=tf.int64)
values_tf = tf.constant([1, 2], dtype=tf.float32)
dense_shape_tf = tf.constant([3, 3], dtype=tf.int64)

sparse_tf = tf.SparseTensor(indices=indices_tf,
                            values=values_tf,
                            dense_shape=dense_shape_tf)
dense_tf = tf.sparse_tensor_to_dense(sparse_tf, default_value=0)
with tf.Session() as sess:
    sparse, dense = sess.run([sparse_tf, dense_tf])
    print('sparse:\n', sparse)
    print('dense:\n', dense)

輸出:

sparse:
 SparseTensorValue(indices=array([[0, 0],
       [1, 1]], dtype=int64), values=array([1., 2.], dtype=float32), dense_shape=array([3, 3], dtype=int64))

dense:
 [[1. 0. 0.]
 [0. 2. 0.]
 [0. 0. 0.]]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章