Transformers具有學習長期依賴的潛力,但在語言模型的設置中受到固定長度上下文的限制。Transformer-XL引入了兩點創新——循環機制(Recurrence Mechanism)和相對位置編碼(Relative Positional Encoding),不僅可以捕獲長距離依賴性,還可以解決上下文碎片問題。Transformer-XL學習的依賴性比RNN長80%,比vanilla Transformers長450%,在短序列和長序列上都能獲得更好的性能,並且在評估過程中比vanilla Transformers快1800倍。
循環機制(Segment-Level Recurrence with State Reuse)
循環機制的目標是通過利用之前段的信息來實現長期依賴性。與vanilla Transformer類似,Transformer-XL處理第一個標記段,但它會保留隱藏層的輸出。處理後面的段時,每個隱藏層都會接收兩個輸入:
- 該段的前一個隱藏層的輸出,和vanilla Transformer相同(如下圖中的灰色箭頭所示)。
- 上一個段的隱藏層的輸出(如綠色箭頭所示),可以使模型創建長期依賴關係。
其中,表示第幾段,表示第幾層,表示隱層的輸出。表示停止計算梯度,表示在長度維度上的兩個隱層的拼接,是模型參數。與Transformer唯一關鍵的不同就在於Key和Value矩陣的計算上,即和,它們基於的是擴展後的上下文隱層狀態進行計算,是之前段的緩存。
從技術上講,這兩個輸入會被拼接,然後用於計算當前段的Key和Value矩陣。該步驟爲網絡提供了更多關於每個表徵的權重(重要性)的信息,但它不會更改Value矩陣。
圖:Transformer-XL語言模型的訓練和測試示意。來源:Transformer-XL
該概念可以擴展到更長的依賴上。使用相同的方法,利用前面多個段的信息,只要GPU內存允許,在測試階段也可以獲得更長的依賴。
循環機制的另一個優點是其測試速度快。在每個步驟中,它可以一次前進一整個段(而不是像在vanilla Transformer中一次只能前進一個表徵),並使用先前段的數據來預測當前段的表徵。
相對位置編碼
循環機制引入了新的挑戰——原始位置編碼將每個段分開處理,因此,來自不同段的表徵會具有相同的位置編碼。例如,第一和第二段的第一個表徵將具有相同的編碼,雖然它們的位置和重要性並不相同(比如第一個段中的第一個表徵可能重要性低一些)。這種混淆可能會錯誤地影響網絡。
針對此問題,論文提出了一種新的位置編碼方式。這種位置編碼是每個注意力模塊的一部分。它不會僅在第一層之前編碼位置,而且會基於表徵之間的相對距離而非絕對位置進行編碼。
在Transformer中,第一層的計算查詢和鍵之間的attention分數的方式爲:
其中,是詞的embedding,是詞的embedding,和是位置向量,這個式子實際上是的展開,就是Transformer中的標準格式。
在Transformer-XL中,對上述的attention計算方式進行了變換,轉爲相對位置的計算,而且不僅僅在第一層這麼計算,在每一層都是這樣計算。
從技術上講,它對注意力頭分數(Attention Head’s Score)的計算方式不再是簡單的乘法(Qi⋅Kj),而是包括四個部分:
- 內容權重——沒有添加原始位置編碼的原始分數。
- 相對於當前內容的位置偏差(Qi)。該項使用正弦類函數來計算表徵之間的相對距離(例如i-j),用以替代當前表徵的絕對位置。
- 可學習的全局內容偏差u——用於調整其他表徵內容(Kj)的重要性。
- 可學習的全局位置偏差v——僅根據表徵之間的距離調整重要性(例如,最後一個詞可能比前一段中的詞更重要)。
Transformer-XL模型的整體計算公式整理如下
multihead attention 代碼如下
def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model,
n_head, d_head, dropout, dropatt, is_training,
kernel_initializer, scope='rel_attn'):
scale = 1 / (d_head ** 0.5)
with tf.variable_scope(scope):
qlen = tf.shape(w)[0]
rlen = tf.shape(r)[0]
bsz = tf.shape(w)[1]
cat = tf.concat([mems, w],
0) if mems is not None and mems.shape.ndims > 1 else w
w_heads = tf.layers.dense(cat, 3 * n_head * d_head, use_bias=False,
kernel_initializer=kernel_initializer, name='qkv')
r_head_k = tf.layers.dense(r, n_head * d_head, use_bias=False,
kernel_initializer=kernel_initializer, name='r')
w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1)
w_head_q = w_head_q[-qlen:]
klen = tf.shape(w_head_k)[0]
w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head])
w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head])
w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head])
r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head])
rw_head_q = w_head_q + r_w_bias
rr_head_q = w_head_q + r_r_bias
AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
BD = rel_shift(BD)
attn_score = (AC + BD) * scale
attn_mask_t = attn_mask[:, :, None, None]
attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
attn_prob = tf.nn.softmax(attn_score, 1)
attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)
attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
size_t = tf.shape(attn_vec)
attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head])
attn_out = tf.layers.dense(attn_vec, d_model, use_bias=False,
kernel_initializer=kernel_initializer, name='o')
attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)
output = tf.contrib.layers.layer_norm(attn_out + w, begin_norm_axis=-1)
return output