tensorflow 構造對角線爲1的矩陣,並mask

import tensorflow as tf
sess = tf.Session()

input = tf.ones([2,3,3])*2
mask = tf.diag(tf.ones([3]))

print(sess.run(mask))
print(sess.run(input * mask))

print結果:
[[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]]

[[[2. 0. 0.]
[0. 2. 0.]
[0. 0. 2.]]
[[2. 0. 0.]
[0. 2. 0.]
[0. 0. 2.]]]

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章