tf.split()
顧名思義就是將tensor分割成爲列表的形式。通常tf.split之後往往會跟tf.concat結合使用。
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
value:準備切分的張量
num_or_size_splits:準備切成幾份
axis : 準備在第幾個維度上進行切割
其中分割方式分爲兩種
- 如果num_or_size_splits 傳入的 是一個整數,那直接在axis=D這個維度上把張量平均切分成幾個小張量
- 如果num_or_size_splits 傳入的是一個向量(這裏向量各個元素的和要跟原本這個維度的數值相等)就根據這個向量有幾個元素分爲幾項)
舉個例子:
# 張量爲(5, 30)
# 這個時候5是axis=0, 30是axis=1,如果要在axis=1這個維度上把這個張量拆分成三個子張量
#傳入向量時
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0) # [5, 4]
tf.shape(split1) # [5, 15]
tf.shape(split2) # [5, 11]
# 傳入整數時
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0) # [5, 10]
tf.tile()
此函數的作用就是將tensor在某一維度上進行擴展,就是將原來的tensor複製多次,然後拼接在制定的axis上。
tf.tile(
input, multiples, name=None
)
舉個例子: