【Tensorflow】RNN常用函數

整理自:TensorFlow中RNN實現的正確打開方式:https://blog.csdn.net/starzhou/article/details/77848156

RNN的基本單元“RNNcell”

  • (output, next_state) = call(input, state)。
    • 每調用一次RNNCell的call方法,就相當於在時間上“推進了一步”,這就是RNNCell的基本功能。
    • 執行一次,序列時間上前進一步。
    • 有兩個子類:BasicRNNCellBasicLSTMCell
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print(cell.state_size) # 隱藏層的大小:128

inputs = tf.placeholder(np.float32,shape=(32,100)) # 32爲batch_size
h0 = cell.zero_state(32,np.float32) #初始狀態爲全0
output,h1=cell.call(input,h0) #調用call函數
print(h1.shape) #(32,128)
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size
h0 = lstm_cell.zero_state(32, np.float32) # 通過zero_state得到一個全0的初始狀態
output, h1 = lstm_cell.call(inputs, h0) #都是(32,128)

一次執行多步

  • tf.nn.dynamic_rnn
    • 相當於調用了n次call函數。
    • time_steps:序列長度。
    • outputs是time_steps步裏所有的輸出。形狀爲(batch_size, time_steps, cell.output_size)。state是最後一步的隱狀態,它的形狀爲(batch_size, cell.state_size)。
inputs = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=128) 

seq_length = tf.placeholder(tf.int32, [None]) # 序列長度
outputs, states = tf.nn.dynamic_rnn(basic_cell, inputs, dtype=tf.float32,sequence_length=seq_length)

如何堆疊RNNCell:MultiRNNCell

  • tf.nn.rnn_cell.MultiRNNCell
    • 實現多層RNN
def get_a_cell():
	return tf.nn.rnn_cell.BasicRNNCell(num_units=128)
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _in range(3)]) # 3層RNN
print(cell.state_size) #(128,128,128) # 並不是128x128x128,而是每個隱層狀態大小爲128

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size
h0 = cell.zero_state(32, np.float32) # 通過zero_state得到一個全0的初始狀態
output, h1 = cell.call(inputs, h0)

print(h1) # tuple中含有3個32x128的向量
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章