import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell
# Roll an LSTM cell forward for NUM_STEPS steps, feeding each step's output
# back in as the next step's input.
# NOTE: tf.contrib and tf.random_normal are TensorFlow 1.x APIs.
NUM_UNITS = 128
BATCH_SIZE = 16
NUM_STEPS = 20

# With state_is_tuple=False the cell state is a single tensor holding
# [c, h] concatenated along axis 1, i.e. [batch_size, 2 * NUM_UNITS].
cell = LSTMCell(NUM_UNITS, state_is_tuple=False)

# Can be the output of an upstream model; shape [BATCH_SIZE, NUM_UNITS].
init_state = tf.random_normal([BATCH_SIZE, NUM_UNITS])

# A NUM_UNITS-wide tensor cannot be passed directly as the state: the cell
# splits the state into c and h halves, so it must be 2 * NUM_UNITS wide.
# Use the upstream output as the initial hidden state h and start the
# memory cell c at zero.
new_state = tf.concat([tf.zeros_like(init_state), init_state], axis=1)

output, new_state = cell(inputs=init_state, state=new_state)
outputs = [output]
for _ in range(NUM_STEPS):
    # Carry the state forward.  Re-feeding init_state every step would reset
    # the recurrence, and `output` (only NUM_UNITS wide) is not a valid state.
    output, new_state = cell(inputs=output, state=new_state)
    outputs.append(output)

# Collect every step's output: [BATCH_SIZE, NUM_STEPS + 1, NUM_UNITS].
outputs = tf.stack(outputs, axis=1)
print(outputs)
# 或者 (alternatively):
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell
# Roll an LSTM cell forward for NUM_STEPS steps, feeding each output back as
# the next input and passing the state through a Dense projection between
# steps.
# NOTE: tf.contrib and tf.random_normal are TensorFlow 1.x APIs.
NUM_UNITS = 128
BATCH_SIZE = 16
NUM_STEPS = 20

# With state_is_tuple=False the state is [c, h] concatenated along axis 1,
# so cell.state_size == 2 * NUM_UNITS.
cell = LSTMCell(NUM_UNITS, state_is_tuple=False)

# Can be the output of an upstream model; shape [BATCH_SIZE, NUM_UNITS].
init_state = tf.random_normal([BATCH_SIZE, NUM_UNITS])

# The state must be 2 * NUM_UNITS wide; use the upstream output as the
# initial hidden state h and start the memory cell c at zero.
new_state = tf.concat([tf.zeros_like(init_state), init_state], axis=1)
output, new_state = cell(inputs=init_state, state=new_state)  # new_state: [batch_size, 2 * NUM_UNITS]

# The projection must preserve the state width.  Projecting to NUM_UNITS
# would make the next cell call fail, because the cell splits its state
# into c and h halves.
mlp = tf.keras.layers.Dense(units=cell.state_size)
new_state = mlp(new_state)

outputs = [output]
for _ in range(NUM_STEPS):
    output, new_state = cell(inputs=output, state=new_state)
    new_state = mlp(new_state)
    outputs.append(output)

# Collect every step's output: [BATCH_SIZE, NUM_STEPS + 1, NUM_UNITS].
outputs = tf.stack(outputs, axis=1)
print(outputs)