import tensorflow as tf
# Build a BERT-style attention mask (1 = real token, 0 = padding) from
# zero-padded token ids.
# NOTE(review): assumes id 0 is the padding id and that padding only appears
# at the end of each row (trailing padding) — confirm for your tokenizer.
bert_input_ids = tf.constant([[1, 2, 3, 0], [1, 2, 0, 0]])

# tf.sign maps positive ids to 1 and padding (0) to 0, so the per-row sum is
# the number of real tokens in each sequence. `axis` replaces the deprecated
# `reduction_indices` keyword.
sequence_len = tf.reduce_sum(tf.sign(bert_input_ids), axis=1)
sequence_len = tf.cast(sequence_len, tf.int32)

# Take maxlen from the input's second dimension instead of hard-coding 4, so
# the mask always has the same width as bert_input_ids.
max_seq_len = tf.shape(bert_input_ids)[1]
bert_mask_ids = tf.sequence_mask(
    sequence_len, maxlen=max_seq_len, dtype=tf.int32
)

# Use the session as a context manager so it is closed (releasing its
# resources) once the mask has been evaluated.
with tf.Session() as sess:
    print(sess.run(bert_mask_ids))
# Printed result:
# [[1 1 1 0]
#  [1 1 0 0]]