論文:Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

Transformers具有學習長期依賴的潛力,但在語言模型的設置中受到固定長度上下文的限制。Transformer-XL引入了兩點創新——循環機制(Recurrence Mechanism)相對位置編碼(Relative Positional Encoding),不僅可以捕獲長距離依賴性,還可以解決上下文碎片問題。Transformer-XL學習的依賴性比RNN長80%,比vanilla Transformers長450%,在短序列和長序列上都能獲得更好的性能,並且在評估過程中比vanilla Transformers快1800倍。

 

循環機制(Segment-Level Recurrence with State Reuse)

循環機制的目標是通過利用之前段的信息來實現長期依賴性。與vanilla Transformer類似,Transformer-XL處理第一個標記段,但它會保留隱藏層的輸出。處理後面的段時,每個隱藏層都會接收兩個輸入:

  1. 該段的前一個隱藏層的輸出,和vanilla Transformer相同(如下圖中的灰色箭頭所示)。
  2. 上一個段的隱藏層的輸出(如綠色箭頭所示),可以使模型創建長期依賴關係。

其中,\tau表示第幾段,n表示第幾層,h表示隱層的輸出。SG(\cdot )表示停止計算梯度,[h_{u}\circ h_{v}]表示在長度維度上的兩個隱層的拼接,W是模型參數。與Transformer唯一關鍵的不同就在於Key和Value矩陣的計算上,即k_{\tau +1}^{n}v_{\tau +1}^{n},它們基於的是擴展後的上下文隱層狀態\tilde{h}_{\tau +1}^{n-1}進行計算,h_{\tau }^{n-1}是之前段的緩存。

從技術上講,這兩個輸入會被拼接,然後用於計算當前段的Key和Value矩陣。該步驟爲網絡提供了更多關於每個表徵的權重(重要性)的信息,但它不會更改Value矩陣。

                                       圖:Transformer-XL語言模型的訓練和測試示意。來源:Transformer-XL
該概念可以擴展到更長的依賴上。使用相同的方法,利用前面多個段的信息,只要GPU內存允許,在測試階段也可以獲得更長的依賴。

循環機制的另一個優點是其測試速度快。在每個步驟中,它可以一次前進一整個段(而不是像在vanilla Transformer中一次只能前進一個表徵),並使用先前段的數據來預測當前段的表徵。
 

相對位置編碼

循環機制引入了新的挑戰——原始位置編碼將每個段分開處理,因此,來自不同段的表徵會具有相同的位置編碼。例如,第一和第二段的第一個表徵將具有相同的編碼,雖然它們的位置和重要性並不相同(比如第一個段中的第一個表徵可能重要性低一些)。這種混淆可能會錯誤地影響網絡。

針對此問題,論文提出了一種新的位置編碼方式。這種位置編碼是每個注意力模塊的一部分。它不會僅在第一層之前編碼位置,而且會基於表徵之間的相對距離而非絕對位置進行編碼。

在Transformer中,第一層的計算查詢q_{i}^{T}和鍵k_{j}之間的attention分數的方式爲:

其中,E_{x_{i}}是詞i的embedding,E_{x_{j}}是詞j的embedding,U_{i}U_{j}是位置向量,這個式子實際上是(W_{q}(E_{x_{i}}+U_{i}))^{T}\cdot (W_{k}(E_{x_{j}}+U_{j}))的展開,就是Transformer中的標準格式。
在Transformer-XL中,對上述的attention計算方式進行了變換,轉爲相對位置的計算,而且不僅僅在第一層這麼計算,在每一層都是這樣計算。

從技術上講,它對注意力頭分數(Attention Head’s Score)的計算方式不再是簡單的乘法(Qi⋅Kj),而是包括四個部分:

  1. 內容權重——沒有添加原始位置編碼的原始分數。
  2. 相對於當前內容的位置偏差(Qi)。該項使用正弦類函數來計算表徵之間的相對距離(例如i-j),用以替代當前表徵的絕對位置。
  3. 可學習的全局內容偏差u——用於調整其他表徵內容(Kj)的重要性。
  4. 可學習的全局位置偏差v——僅根據表徵之間的距離調整重要性(例如,最後一個詞可能比前一段中的詞更重要)。
     

Transformer-XL模型的整體計算公式整理如下

 

multihead attention 代碼如下

def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model,
                       n_head, d_head, dropout, dropatt, is_training,
                       kernel_initializer, scope='rel_attn'):
  scale = 1 / (d_head ** 0.5)
  with tf.variable_scope(scope):
    qlen = tf.shape(w)[0]
    rlen = tf.shape(r)[0]
    bsz = tf.shape(w)[1]

    cat = tf.concat([mems, w],
                    0) if mems is not None and mems.shape.ndims > 1 else w
    w_heads = tf.layers.dense(cat, 3 * n_head * d_head, use_bias=False,
                              kernel_initializer=kernel_initializer, name='qkv')
    r_head_k = tf.layers.dense(r, n_head * d_head, use_bias=False,
                               kernel_initializer=kernel_initializer, name='r')

    w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1)
    w_head_q = w_head_q[-qlen:]

    klen = tf.shape(w_head_k)[0]

    w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head])
    w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head])
    w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head])

    r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head])

    rw_head_q = w_head_q + r_w_bias
    rr_head_q = w_head_q + r_r_bias

    AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
    BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
    BD = rel_shift(BD)

    attn_score = (AC + BD) * scale
    attn_mask_t = attn_mask[:, :, None, None]
    attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t

    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
    size_t = tf.shape(attn_vec)
    attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head])

    attn_out = tf.layers.dense(attn_vec, d_model, use_bias=False,
                               kernel_initializer=kernel_initializer, name='o')
    attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)

    output = tf.contrib.layers.layer_norm(attn_out + w, begin_norm_axis=-1)
  return output

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章