Resources
- Full code with detailed comments: github
How it works
The Transformer model comes from the paper Attention Is All You Need. It was proposed for machine translation: the Self-Attention mechanism and Positional Encoding replace the RNN structure of traditional Seq2Seq models. Thanks to the Transformer's strong performance, later models built on it: OpenAI GPT uses the Transformer's Decoder stack, while BERT uses its Encoder stack.
Transformer processing flow:
Input: inputs, targets
For example:
inputs = 'SOS 想象力 比 知識 更 重要 EOS'
targets = 'SOS imagination is more important than knowledge EOS'
Training
Training uses teacher forcing.
inputs = 'SOS 想象力 比 知識 更 重要 EOS'
targets = 'SOS imagination is more important than knowledge EOS'
The target is split into tar_inp and tar_real. tar_inp is passed to the Decoder as input. tar_real is the same sequence shifted by one position: for each position in tar_inp, tar_real holds the next token that should be predicted.
tar_inp = 'SOS imagination is more important than knowledge'
tar_real = 'imagination is more important than knowledge EOS'
That is, the Encoder first encodes inputs; the Decoder then starts from SOS and predicts the probability of the next token step by step. Because training uses teacher forcing, the ground-truth previous tokens (not the model's own predictions) are fed in at every step.
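The tar_inp / tar_real split is just a one-token shift. A minimal sketch on hypothetical token ids (the ids 1, 2, 7–12 are chosen purely for illustration):

```python
# Hypothetical ids: 1 = SOS, 2 = EOS, the rest are word ids (illustrative only)
targets = [1, 7, 8, 9, 10, 11, 12, 2]  # 'SOS imagination ... knowledge EOS'

tar_inp = targets[:-1]   # keeps SOS, drops EOS: fed to the Decoder
tar_real = targets[1:]   # drops SOS, keeps EOS: what the Decoder should predict

print(tar_inp)   # [1, 7, 8, 9, 10, 11, 12]
print(tar_real)  # [7, 8, 9, 10, 11, 12, 2]
```

This is exactly what `tar[:, :-1]` and `tar[:, 1:]` do batch-wise in train_step below.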
Predicted output
tar_pred = 'imagination is more important than knowledge EOS'
This is of course the ideal case, where the prediction exactly matches tar_real; early in training the predictions will not be this accurate.
Loss: cross-entropy
The cross-entropy loss is computed between tar_pred and tar_real.
How does the trained model predict?
Here SOS is the start-of-sentence marker and EOS the end-of-sentence marker.
Encoder stage: inputs = 'SOS 想象力 比 知識 更 重要 EOS'
Decoder stage: predict in a loop
Feed [SOS], predict the next token: imagination
Feed [SOS, imagination], predict the next token: is
…
Feed [SOS, imagination, is, more, important, than, knowledge], predict EOS, and decoding ends.
Decoding stops under either of two conditions: EOS is predicted, or the maximum target_seq_len is reached.
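The loop above can be sketched independently of the model; here predict_next is a stand-in for the trained Transformer, and the token ids are illustrative:

```python
def greedy_decode(predict_next, sos_id, eos_id, max_len):
    '''predict_next(prefix) -> next token id; a stand-in for the trained model.'''
    output = [sos_id]
    for _ in range(max_len):
        next_id = predict_next(output)
        if next_id == eos_id:   # stop condition 1: EOS predicted
            break
        output.append(next_id)
    return output               # stop condition 2: max_len reached

# Toy predictor that emits 5, 6, then EOS (assumed ids: SOS=1, EOS=2)
script = iter([5, 6, 2])
print(greedy_decode(lambda prefix: next(script), sos_id=1, eos_id=2, max_len=10))
# [1, 5, 6]
```

The predict method at the end of this post implements the same loop with the real model and masks.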
Network architecture
Architecture from the original paper
The architecture as implemented here:
Encoder:
Notation used in the pseudocode below:
MultiHeadAttention(v, k, q, mask)
Encoder block
consists of two sub-layers:
- multi-head attention (with padding mask)
- point-wise feed forward network, which is simply two fully connected layers
The input x is the input sentences, shaped (batch_size, seq_len, d_model)
- out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
- out2 = LayerNormalization( out1 + (ffn(out1) => dropout) )
Decoder部分:
和Encoder部分區別在於,Decoder部分先對自身做了Self-Attention後,在作爲query,對Encoder的輸出作爲key和value,進行普通Attention後的結果,作爲 feed forward的輸入
Decoder block,需要的子層:
- 遮擋的多頭注意力(前瞻遮擋和填充遮擋)
- 多頭注意力(用填充遮擋)。V(數值)和 K(主鍵)接收編碼器輸出作爲輸入。Q(請求)接收遮擋的多頭注意力子層的輸出。
- 點式前饋網絡
輸入x爲target_sentents, (batch_size, seq_len, d_model)
- out1 = BatchNormalization( x +(MultiHeadAttention(x, x, x)=>dropout))
- out2 = BatchNormalization( out1 +(MultiHeadAttention(enc_output, enc_output out1)=>dropout))
- out3 = BatchNormalization( out2 + (ffn(out2) => dropout) )
Implementation
Position
def get_angles(pos, i, d_model):
    '''
    :param pos: position of the token in the sentence
    :param i: index along the embedding dimension
    :param d_model: embedding dimension size
    :return: pos * angle_rates, shape (position, d_model)
    '''
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding(position, d_model):
    '''
    :param position: maximum position
    :param d_model: embedding dimension size
    :return: (1, position, d_model); added to the embedding matrix
    '''
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)
point_wise_feed_forward_network
def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
        tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    ])
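The feed forward network is applied to each position independently; a plain-numpy sketch of the same computation (weights here are random stand-ins, not trained parameters):

```python
import numpy as np

def ffn_np(x, w1, b1, w2, b2):
    # Dense(dff, relu) followed by Dense(d_model), position-wise
    hidden = np.maximum(x @ w1 + b1, 0.0)  # (batch, seq_len, dff)
    return hidden @ w2 + b2                # (batch, seq_len, d_model)

rng = np.random.default_rng(0)
d_model, dff = 8, 32
x = rng.normal(size=(2, 5, d_model))
out = ffn_np(x,
             rng.normal(size=(d_model, dff)), np.zeros(dff),
             rng.normal(size=(dff, d_model)), np.zeros(d_model))
print(out.shape)  # (2, 5, 8)
```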
Attention
MultiHeadAttention splits the d_model (word embedding) dimension into num_heads pieces and runs scaled dot-product attention on each piece in parallel.
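The head split is just a reshape plus a transpose; a shape-only numpy sketch (the Keras layer below does the same thing with tf.reshape and tf.transpose):

```python
import numpy as np

batch_size, seq_len, d_model, num_heads = 2, 5, 16, 4
depth = d_model // num_heads  # 4: each head works on its own slice of d_model

x = np.zeros((batch_size, seq_len, d_model))
x = x.reshape(batch_size, seq_len, num_heads, depth)  # split d_model into heads
x = x.transpose(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, depth)
print(x.shape)  # (2, 4, 5, 4)
```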
def scaled_dot_product_attention(q, k, v, mask=None):
    '''Compute attention.
    q, k, v must agree in the leading (batch) dimensions.
    q and k must agree in the last dimension.
    k and v must agree in the second-to-last dimension: seq_len_k = seq_len_v = seq_len.
    Args:
        q: queries, shape == (..., seq_len_q, d)
        k: keys, shape == (..., seq_len, d)
        v: values, shape == (..., seq_len, d_v)
        mask: float tensor broadcastable to
              (..., seq_len_q, seq_len). Defaults to None.
    Returns:
        output, attention weights
    '''
    # (batch_size, num_heads, seq_len_q, d) dot (batch_size, num_heads, d, seq_len) = (batch_size, num_heads, seq_len_q, seq_len)
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], dtype=tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    # add the mask to the scaled logits
    if mask is not None:
        # (batch_size, num_heads, seq_len_q, seq_len) + (batch_size, 1, 1, seq_len)
        scaled_attention_logits += (mask * -1e9)
    # softmax-normalized weights, (batch_size, num_heads, seq_len_q, seq_len)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    # for each of the seq_len_q positions, a weighted sum over v
    # (batch_size, num_heads, seq_len_q, seq_len) dot (batch_size, num_heads, seq_len, d_v) = (batch_size, num_heads, seq_len_q, d_v)
    output = tf.matmul(attention_weights, v)
    return output, attention_weights
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert (d_model > num_heads) and (d_model % num_heads == 0)
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads
        self.qw = tf.keras.layers.Dense(d_model)
        self.kw = tf.keras.layers.Dense(d_model)
        self.vw = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))  # (batch_size, seq_len, num_heads, depth)
        return tf.transpose(x, perm=(0, 2, 1, 3))  # (batch_size, num_heads, seq_len, depth)

    def call(self, v, k, q, mask=None):
        batch_size = tf.shape(q)[0]
        q = self.qw(q)  # (batch_size, seq_len_q, d_model)
        k = self.kw(k)  # (batch_size, seq_len, d_model)
        v = self.vw(v)  # (batch_size, seq_len, d_model)
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len, depth_v)
        # scaled_attention, (batch_size, num_heads, seq_len_q, depth_v)
        # attention_weights, (batch_size, num_heads, seq_len_q, seq_len)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
        scaled_attention = tf.transpose(scaled_attention, perm=(0, 2, 1, 3))  # (batch_size, seq_len_q, num_heads, depth_v)
        concat_attention = tf.reshape(scaled_attention, shape=(batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
Encoder
Input:
- inputs (batch_size, seq_len_inp, d_model)
- mask (batch_size, 1, 1, seq_len_inp). Because input sequences are padded to a common length, the padded positions must be masked during self-attention. The mask has shape (batch_size, 1, 1, seq_len_inp) because MultiHeadAttention splits the inputs into (batch_size, num_heads, seq_len_inp, d_model // num_heads) and produces attention weights of shape (batch_size, num_heads, seq_len_inp, seq_len_inp); when the mask is applied, it broadcasts automatically to (batch_size, num_heads, seq_len_inp, seq_len_inp).
Output:
- encode_output (batch_size, seq_len_inp, d_model)
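The broadcasting described above can be checked with plain numpy on toy shapes:

```python
import numpy as np

batch_size, num_heads, seq_len = 2, 4, 6
logits = np.zeros((batch_size, num_heads, seq_len, seq_len))
mask = np.zeros((batch_size, 1, 1, seq_len))
mask[:, :, :, -2:] = 1.0  # pretend the last two positions are padding

masked = logits + mask * -1e9  # mask broadcasts over heads and query positions
print(masked.shape)          # (2, 4, 6, 6)
print(masked[0, 0, 0, -1])   # -1000000000.0
```

After softmax, those -1e9 logits give the padded key positions effectively zero attention weight.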
class EncoderLayer(tf.keras.layers.Layer):
    '''Encoder block
    Two sub-layers: 1. multi-head attention (with padding mask) 2. point-wise feed forward network.
    out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
    out2 = LayerNormalization( out1 + (ffn(out1) => dropout) )
    '''
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layer_norm1(x + attn_output)  # (batch_size, input_seq_len, d_model)
        ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layer_norm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)
        return out2
class Encoder(tf.keras.layers.Layer):
    '''
    Input Embedding
    Positional Encoding
    N encoder layers
    The input is embedded, and the embedding is summed with the positional encoding. That sum is the input to the encoder layers. The encoder output is the input to the decoder.
    '''
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.enc_layer = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        # x.shape == (batch_size, seq_len)
        seq_len = tf.shape(x)[1]
        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, dtype=tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x = self.enc_layer[i](x, training, mask)
        return x  # (batch_size, input_seq_len, d_model)
Decoder
Input:
- targets_inp (batch_size, seq_len_tar, d_model)
- encode_output (batch_size, seq_len_inp, d_model)
- self_mask (batch_size, 1, 1, seq_len_tar), enc_output_mask (batch_size, 1, 1, seq_len_inp)
Output:
- decode_output (batch_size, seq_len_tar, d_model); the final linear layer then projects this to the target vocabulary size
class DecoderLayer(tf.keras.layers.Layer):
    '''Decoder block
    Sub-layers:
    1. masked multi-head attention (look-ahead mask and padding mask)
    2. multi-head attention (with padding mask). V (values) and K (keys) receive the encoder output; Q (queries) receives the output of the masked multi-head attention sub-layer.
    3. point-wise feed forward network
    out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
    out2 = LayerNormalization( out1 + (MultiHeadAttention(enc_output, enc_output, out1) => dropout) )
    out3 = LayerNormalization( out2 + (ffn(out2) => dropout) )
    '''
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # x.shape == (batch_size, target_seq_len, d_model)
        # enc_output.shape == (batch_size, input_seq_len, d_model)
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layer_norm1(x + attn1)
        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layer_norm2(out1 + attn2)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layer_norm3(out2 + ffn_output)  # (batch_size, target_seq_len, d_model)
        return out3, attn_weights_block1, attn_weights_block2
class Decoder(tf.keras.layers.Layer):
    '''The decoder consists of:
    Output Embedding
    Positional Encoding
    N decoder layers
    The target is embedded, and the embedding is summed with the positional encoding. That sum is the input to the decoder layers. The decoder output is the input to the final linear layer.
    '''
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.dec_layer = [DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # x.shape == (batch_size, target_seq_len)
        # enc_output.shape == (batch_size, input_seq_len, d_model)
        seq_len = tf.shape(x)[1]
        attention_weights = {}
        x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layer[i](x, enc_output, training, look_ahead_mask, padding_mask)
            attention_weights['decoder_layer{}_block1'.format(i + 1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i + 1)] = block2
        # x.shape == (batch_size, target_seq_len, d_model)
        return x, attention_weights
Transformer
class Transformer(tf.keras.Model):
    def __init__(self, params):
        super(Transformer, self).__init__()
        self.encoder = Encoder(params['num_layers'], params['d_model'], params['num_heads'], params['dff'],
                               params['input_vocab_size'], params['pe_input'], params['rate'])
        self.decoder = Decoder(params['num_layers'], params['d_model'], params['num_heads'], params['dff'],
                               params['target_vocab_size'], params['pe_target'], params['rate'])
        self.final_layer = tf.keras.layers.Dense(params['target_vocab_size'])

    def call(self, inp, tar, training, enc_padding_mask=None, look_ahead_mask=None, dec_padding_mask=None):
        # (batch_size, inp_seq_len, d_model)
        enc_output = self.encoder(inp, training, enc_padding_mask)
        # (batch_size, tar_seq_len, d_model)
        dec_output, attention_weights = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask)
        final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)
        return final_output, attention_weights
Mask
def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # add extra dimensions so the padding can be added
    # to the attention logits
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

def create_look_ahead_mask(size):
    '''
    e.g.
    x = tf.random.uniform((1, 3))
    temp = create_look_ahead_mask(x.shape[1])
    temp: <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
    array([[0., 1., 1.],
           [0., 0., 1.],
           [0., 0., 0.]], dtype=float32)>
    '''
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # (seq_len, seq_len)
def create_masks(inp, tar):
    # encoder padding mask
    enc_padding_mask = create_padding_mask(inp)
    # used in the decoder's second attention block;
    # this padding mask masks the encoder output
    dec_padding_mask = create_padding_mask(inp)
    # used in the decoder's first attention block;
    # pads and masks future tokens in the decoder input
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])  # (tar_seq_len, tar_seq_len)
    dec_target_padding_mask = create_padding_mask(tar)  # (batch_size, 1, 1, tar_seq_len)
    # broadcasting: look_ahead_mask ==> (batch_size, 1, tar_seq_len, tar_seq_len)
    # dec_target_padding_mask ==> (batch_size, 1, tar_seq_len, tar_seq_len)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    return enc_padding_mask, combined_mask, dec_padding_mask
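The combined mask can be reproduced in plain numpy; np.triu here plays the role of tf.linalg.band_part, and np.maximum the role of tf.maximum:

```python
import numpy as np

tar = np.array([[7, 8, 0]])  # one target sequence, last position padded (0 = pad)
pad_mask = (tar == 0).astype(np.float32)[:, None, None, :]    # (1, 1, 1, 3)
look_ahead = np.triu(np.ones((3, 3), dtype=np.float32), k=1)  # (3, 3)

combined = np.maximum(pad_mask, look_ahead)  # broadcasts to (1, 1, 3, 3)
print(combined[0, 0])
# [[0. 1. 1.]
#  [0. 0. 1.]
#  [0. 0. 1.]]
```

Each row is one query position: it may attend only to earlier positions, and the padded third position stays masked for every query.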
Putting it all together
# ==============================================================
params = {
    'num_layers': 4,
    'd_model': 128,
    'dff': 512,
    'num_heads': 8,
    'input_vocab_size': tokenizer_pt.vocab_size + 2,
    'target_vocab_size': tokenizer_en.vocab_size + 2,
    'pe_input': tokenizer_pt.vocab_size + 2,
    'pe_target': tokenizer_en.vocab_size + 2,
    'rate': 0.1,
    'checkpoint_path': './checkpoints/train',
    'checkpoint_do_delete': False
}
print('input_vocab_size is {}, target_vocab_size is {}'.format(params['input_vocab_size'], params['target_vocab_size']))
class ModelHelper:
    def __init__(self):
        self.transformer = Transformer(params)
        # optimizer
        learning_rate = CustomSchedule(params['d_model'])
        self.optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
        # accumulates per-batch loss over an epoch, then averages it into the epoch loss
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        # accumulates per-batch accuracy over an epoch, then averages it into the epoch accuracy
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
        self.test_loss = tf.keras.metrics.Mean(name='test_loss')
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
        # checkpoints: create params['checkpoint_path'] if it does not exist;
        # if it exists and checkpoint_do_delete=True, delete it first and then recreate it
        checkout_dir(dir_path=params['checkpoint_path'], do_delete=params.get('checkpoint_do_delete', False))
        ckpt = tf.train.Checkpoint(transformer=self.transformer,
                                   optimizer=self.optimizer)
        self.ckpt_manager = tf.train.CheckpointManager(ckpt, params['checkpoint_path'], max_to_keep=5)
        # if a checkpoint exists, restore the latest one
        if self.ckpt_manager.latest_checkpoint:
            ckpt.restore(self.ckpt_manager.latest_checkpoint)
            print('Latest checkpoint restored!!')

    def loss_function(self, real, pred):
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss_ = self.loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        # average over real (non-padding) tokens only
        return tf.reduce_sum(loss_) / tf.reduce_sum(mask)
    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    ]

    @tf.function(input_signature=train_step_signature)
    def train_step(self, inp, tar):
        tar_inp = tar[:, :-1]
        tar_real = tar[:, 1:]
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
        with tf.GradientTape() as tape:
            predictions, _ = self.transformer(inp, tar_inp,
                                              True,
                                              enc_padding_mask,
                                              combined_mask,
                                              dec_padding_mask)
            loss = self.loss_function(tar_real, predictions)
        gradients = tape.gradient(loss, self.transformer.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.transformer.trainable_variables))
        self.train_loss(loss)
        self.train_accuracy(tar_real, predictions)

    @tf.function
    def test_step(self, inp, labels):
        predictions = self.predict(inp)
        t_loss = self.loss_object(labels, predictions)
        self.test_loss(t_loss)
        self.test_accuracy(labels, predictions)

    def train(self, train_dataset):
        for epoch in range(params['epochs']):
            start = time.time()
            self.train_loss.reset_states()
            self.train_accuracy.reset_states()
            # inp -> portuguese, tar -> english
            for (batch, (inp, tar)) in enumerate(train_dataset):
                self.train_step(inp, tar)
                if batch % 50 == 0:
                    print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, batch, self.train_loss.result(), self.train_accuracy.result()))
            if (epoch + 1) % 5 == 0:
                ckpt_save_path = self.ckpt_manager.save()
                print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))
            print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, self.train_loss.result(), self.train_accuracy.result()))
            print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
    # evaluation
    def predict(self, inp_sentence):
        start_token = [tokenizer_pt.vocab_size]
        end_token = [tokenizer_pt.vocab_size + 1]
        # the input sentence is Portuguese; add the start and end tokens
        inp_sentence = start_token + tokenizer_pt.encode(inp_sentence) + end_token
        encoder_input = tf.expand_dims(inp_sentence, 0)
        # the target is English, so the first token fed to the
        # transformer is the English start token
        decoder_input = [tokenizer_en.vocab_size]
        output = tf.expand_dims(decoder_input, 0)
        attention_weights = None
        for i in range(MAX_LENGTH):
            enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
                encoder_input, output)
            # predictions.shape == (batch_size, seq_len, vocab_size)
            predictions, attention_weights = self.transformer(encoder_input,
                                                              output,
                                                              False,
                                                              enc_padding_mask,
                                                              combined_mask,
                                                              dec_padding_mask)
            # select the last token along the seq_len dimension
            predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)
            predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
            # if predicted_id equals the end token, return the result
            if predicted_id == tokenizer_en.vocab_size + 1:
                return tf.squeeze(output, axis=0), attention_weights
            # concatenate predicted_id to the output, which becomes the next decoder input
            output = tf.concat([output, predicted_id], axis=-1)
        return tf.squeeze(output, axis=0), attention_weights
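The padding-masked loss used in loss_function can be sketched in numpy; note that averaging over real (non-padding) tokens only, as the official TensorFlow Transformer tutorial does, keeps padded positions from deflating the loss. The numbers below are illustrative:

```python
import numpy as np

# Per-token cross-entropy and target ids for one batch row (0 = padding)
per_token_loss = np.array([0.5, 1.0, 2.0, 3.0])
real = np.array([7, 8, 0, 0])  # the last two positions are padding

mask = (real != 0).astype(per_token_loss.dtype)
loss = (per_token_loss * mask).sum() / mask.sum()
print(loss)  # 0.75 -> only the two real tokens contribute
```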