Bert代碼解讀記錄

代碼學習的是前一篇博客中pytorch的代碼的BertForTokenClassification模型,run的是ner例子:https://github.com/huggingface/transformers/blob/master/examples/run_ner.py

1、模型概覽:

使用的模型是:multi_cased_L-12_H-768_A-12.zip,https://github.com/google-research/bert/blob/master/README.md

可以看出模型是用transformer中的encoder堆積而成的,一共12層,最後一層是分類層,會對每個token預測ner數據中所有label的概率。

2、輸入:

閱讀代碼時,設置的batchsize=3,爲了方便觀看數據,將模型需要用到的輸入數據都列出部分。基本有

  • input_ids: token的索引,後面直接用來找到token對應的embedding。
  • attention_mask:用來標記token是否padding出來的。(在計算attention的時候,他們的值會變成0.實現方法是pad出來的token的attention得分加上-10000. 注意區別代碼裏attention得分和attention概率,attention概率是attention得分乘上Value之後的結果。)
  • token_type_ids:標記輸入的token是第一句話的token還是第二句話

3、主體流程:

通過input_ids找到embedding(這個embedding是:

embeddings = inputs_embeds + position_embeddings(位置embedding) + token_type_embeddings(句子embedding)

已經算好的結果),然後順圖中標記的順序看,返回過程沒有用線連了。

具體過程分開理:(配合下面的貼圖使用)

首先

是1、10、11標記的圖片,是模型的主體。embedding通過encoder拿到token的表示(10中的output[0]),過了一個線性分類器得到每個token對應的label的概率分佈,再利用attention_mask索引出未pading的token,最後的loss使用交叉熵直接算未pading的token。(這裏並不是訓練bert,所以並不是預測mask的token,這裏也沒mask的token了)

然後

主要關注encoder:主要是圖中的2、3、4、5、6、7、8、9幾個圖。

  1. 首先embedding通過2、3、4、5過一個QKV的attention,具體實現可以看5,其中有mask吊pad的token的操作
  2. 接着attention出來後過了一個dense層,一個dropout層,Norm層,其中有殘差連接,圖6。
  3. 然後又過了一個dense層和gelu激活,gelu實現見圖8。
  4. 繼續過了一個dense+dropout+Norm+殘差連接四連擊,見圖9

上圖的過程是從代碼中看出的,和論文 《attention is all you need》中的圖完全一致:代碼中堆了11層紅框的結構。

 

附:

一開始一直沒弄懂爲什麼bert是雙向的,因爲MLM(mask language model),訓練的時候根據上下文預測mask,上下文體現出了雙向的意思。

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章