轉載公衆號:https://mp.weixin.qq.com/s/0k71fKKv2SRLv9M6BjDo4w
原創: 盛源車 機器學習算法與自然語言處理 1周前
作者:盛源車
知乎專欄:魔法抓的學習筆記
五分鐘看懂seq2seq attention模型。
本文通過圖片,詳細地畫出了seq2seq+attention模型的全部流程,幫助小夥伴們無痛理解機器翻譯等任務的重要模型。
seq2seq 是一個Encoder–Decoder 結構的網絡,它的輸入是一個序列,輸出也是一個序列, Encoder 中將一個可變長度的信號序列變爲固定長度的向量表達,Decoder 將這個固定長度的向量變成可變長度的目標的信號序列。--簡書
好了別管了,接下來開始刷圖吧。
大框架
想象一下翻譯任務,input是一段英文,output是一段中文。
公式(直接跳過看圖最佳)
詳細圖
左側爲Encoder+輸入,右側爲Decoder+輸出。中間爲Attention。
從左邊Encoder開始,輸入轉換爲word embedding, 進入LSTM。LSTM會在每一個時間點上輸出hidden states。如圖中的h1,h2,...,h8。
接下來進入右側Decoder,輸入爲(1) 句首 <sos>符號,原始context vector(爲0),以及從encoder最後一個hidden state: h8。LSTM的是輸出是一個hidden state。(當然還有cell state,這裏沒用到,不提。)
Decoder的hidden state與Encoder所有的hidden states作爲輸入,放入Attention模塊開始計算一個context vector。之後會介紹attention的計算方法。
下一個時間點
來到時間點2,之前的context vector可以作爲輸入和目標的單詞串起來作爲lstm的輸入。之後又回到一個hiddn state。以此循環。
另一方面,context vector和decoder的hidden state合起來通過一系列非線性轉換以及softmax最後計算出概率。
在luong中提到了三種score的計算方法。這裏圖解前兩種:
Attention score function: dot
輸入是encoder的所有hidden states H: 大小爲(hid dim, sequence length)。decoder在一個時間點上的hidden state, s: 大小爲(hid dim, 1)。
第一步:旋轉H爲(sequence length, hid dim) 與s做點乘得到一個 大小爲(sequence length, 1)的分數。
第二步:對分數做softmax得到一個合爲1的權重。
第三步:將H與第二步得到的權重做點乘得到一個大小爲(hid dim, 1)的context vector。
Attention score function: general
輸入是encoder的所有hidden states H: 大小爲(hid dim1, sequence length)。decoder在一個時間點上的hidden state, s: 大小爲(hid dim2, 1)。此處兩個hidden state的緯度並不一樣。
第一步:旋轉H爲(sequence length, hid dim1) 與 Wa [大小爲 hid dim1, hid dim 2)] 做點乘, 再和s做點乘得到一個 大小爲(sequence length, 1)的分數。
第二步:對分數做softmax得到一個合爲1的權重。
第三步:將H與第二步得到的權重做點乘得到一個大小爲(hid dim, 1)的context vector。
完結
看懂一個模型的最好辦法就是在心裏想一遍從輸入到模型到輸出每一個步驟裏,tensor是如何流動的。