TensorFlow 構造sequence的0/1 mask

import tensorflow as tf
bert_input_ids = tf.constant([[1,2,3,0],[1,2,0,0]])
sequence_len = tf.reduce_sum(tf.sign(bert_input_ids), reduction_indices=1)
sequence_len = tf.cast(sequence_len, tf.int32)
bert_mask_ids = tf.sequence_mask(sequence_len,4,tf.int32)
sess = tf.Session()
print(sess.run(bert_mask_ids))

print結果:

[[1 1 1 0]
 [1 1 0 0]]
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章