Bert (Bi-directional Encoder Representations from Transformers) Pytorch 源碼解讀(二)

前言

這裏是 Bert(Bi-directional Encoder Representations from Transformers) 源碼解讀的第二部分,第一部分主要介紹了 bert_model.py 文件中, bert 模型的定義。而第二部分爲 BERT_Training.py 文件,該部分源碼主要實現了 Bert 模型的預訓練工作。


Bert 源碼解讀:

1. 模型結構源碼: bert_model.py

2. 模型預訓練源碼:bert_training.py

3. 數據預處理源碼:wiki_dataset.py


在開始前,先大致的介紹一下 bert 模型的預訓練。bert 的預訓練過程是個 Multi Task learning 的過程。其中同時進行的兩個任務分別爲:

1. Masked Language Modeling。

2. Next sentece Predict(Classification)。

Masked Language Mode(MLM)選擇輸入序列中的隨機token樣本,並用特殊的token[MASK]替換。MLM的目標是預測遮擋token時的交叉熵損失。BERT一致選擇15%的輸入token作爲可能的替換。在所選的token中,80%替換爲[MASK], 10%保持不變,10%替換爲隨機選擇的詞彙表token。MLM 任務在 bert 中的主要任務是建立語言模型,他的原理和 Word2Vec 中的 CBOW 模型以及 Negative Sampling 算法思想類似。

Next Sentece Predict(NSP)是一種二分類損失,用於預測兩個片段在原文中是否相互跟隨。通過從文本語料庫中提取連續的句子來創建積極的例子。反例是通過對來自不同文檔的段進行配對來創建的。正、負樣本的抽樣概率相等。NSP的目標是爲了提高下游任務的性能,比如自然語言推理,這需要對句子對之間的關係進行推理。


開始

1. Import & Config

from torch.utils.data import DataLoader
from dataset.wiki_dataset import BERTDataset
from models.bert_model import *
import tqdm
import pandas as pd
import numpy as np
import os


config = {}
config["train_corpus_path"] = "./pretraining_data/wiki_dataset/test_wiki.txt"
config["test_corpus_path"] = "./pretraining_data/wiki_dataset/test_wiki.txt"
config["word2idx_path"] = "./pretraining_data/wiki_dataset/bert_word2idx_extend.json"
config["output_path"] = "./output_wiki_bert"

config["batch_size"] = 1
config["max_seq_len"] = 200
config["vocab_size"] = 32162
config["lr"] = 2e-6
config["num_workers"] = 0

首先導入各種所需要的庫,之前定義好的 Bert 模型,以及數據處理模塊,其次對一些參數及路徑進行設置。參數分別有:batch_sizemax_seq_len 最大序列長度,vocab_size 字典大小,lr 學習率,num_workers 加載數據時的線程數。

2. Pretrain

class Pretrainer:
    def __init__(self, bert_model,
                 vocab_size,
                 max_seq_len,
                 batch_size,
                 lr,
                 with_cuda=True,
                 ):
        # 詞量, 注意在這裏實際字(詞)匯量 = vocab_size - 20,
        # 因爲前20個token用來做一些特殊功能, 如padding等等
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        # 學習率
        self.lr = lr
        # 是否使用GPU
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        # 限定的單句最大長度
        self.max_seq_len = max_seq_len
        # 初始化超參數的配置
        bertconfig = BertConfig(vocab_size=config["vocab_size"])
        # 初始化bert模型
        self.bert_model = bert_model(config=bertconfig)
        self.bert_model.to(self.device)
        # 初始化訓練數據集
        train_dataset = BERTDataset(corpus_path=config["train_corpus_path"],
                                    word2idx_path=config["word2idx_path"],
                                    seq_len=self.max_seq_len,
                                    hidden_dim=bertconfig.hidden_size,
                                    on_memory=False,
                                    )
        # 初始化訓練dataloader
        self.train_dataloader = DataLoader(train_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=config["num_workers"],
                                           collate_fn=lambda x: x)
        # 初始化測試數據集
        test_dataset = BERTDataset(corpus_path=config["test_corpus_path"],
                                   word2idx_path=config["word2idx_path"],
                                   seq_len=self.max_seq_len,
                                   hidden_dim=bertconfig.hidden_size,
                                   on_memory=True,
                                   )
        # 初始化測試dataloader
        self.test_dataloader = DataLoader(test_dataset, batch_size=self.batch_size,
                                          num_workers=config["num_workers"],
                                          collate_fn=lambda x: x)
        # 初始化positional encoding
        self.positional_enc = self.init_positional_encoding(hidden_dim=bertconfig.hidden_size,
                                                            max_seq_len=self.max_seq_len)
        # 拓展positional encoding的維度爲[1, max_seq_len, hidden_size]
        self.positional_enc = torch.unsqueeze(self.positional_enc, dim=0)

        # 列舉需要優化的參數並傳入優化器
        optim_parameters = list(self.bert_model.parameters())
        self.optimizer = torch.optim.Adam(optim_parameters, lr=self.lr)

        print("Total Parameters:", sum([p.nelement() for p in self.bert_model.parameters()]))

__init__ 方法中,主要進行 bert 預訓練前的一些準備工作,包括:定義字典大小、batch size、learning rate、GPU的指定、最大句子長度、以及 Bert 的超參設置。同時對訓練數據、測試數據、positional encoding進行初始化,定義optimizer。這部分代碼很簡單,每一部分都在註釋中進行了說明,這裏就不再進行一一說明。

3.Positional encoding

    def init_positional_encoding(self, hidden_dim, max_seq_len):
        position_enc = np.array([
            [pos / np.power(10000, 2 * i / hidden_dim) for i in range(hidden_dim)]
            if pos != 0 else np.zeros(hidden_dim) for pos in range(max_seq_len)])

        position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
        position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
        denominator = np.sqrt(np.sum(position_enc**2, axis=1, keepdims=True))
        position_enc = position_enc / (denominator + 1e-8)
        position_enc = torch.from_numpy(position_enc).type(torch.FloatTensor)
        return position_enc

Positional encoding 的初始化代碼,這裏使用的 Positional encoding 的初始化方式與傳統 Transformer 相同,通過 sin 與 cos 函數生成固定的值,具體計算方法爲:

 這裏需要注意的是,在 Google 官方開源的 TensorFlow 版本的 bert 源碼中,Positional encoding 是使用的跟字向量相同的方法初始化,經過模型訓練得到的,而這裏使用了傳統 Transformer 的直接計算方法。

4. load_model & train

    def load_model(self, model, dir_path="./output"):
        # 加載模型
        checkpoint_dir = self.find_most_recent_state_dict(dir_path)
        checkpoint = torch.load(checkpoint_dir)
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        torch.cuda.empty_cache()
        model.to(self.device)
        print("{} loaded for training!".format(checkpoint_dir))

    def train(self, epoch, df_path="./output_wiki_bert/df_log.pickle"):
        self.bert_model.train()
        self.iteration(epoch, self.train_dataloader, train=True, df_path=df_path)

load_model 用於模型的加載,train 方法對模型進行迭代訓練。

5. NSP loss & MLM loss

    def compute_loss(self, predictions, labels, num_class=2, ignore_index=None):
        if ignore_index is None:
            loss_func = CrossEntropyLoss()
        else:
            loss_func = CrossEntropyLoss(ignore_index=ignore_index)
        return loss_func(predictions.view(-1, num_class), labels.view(-1))

    def get_mlm_accuracy(self, predictions, labels):
        predictions = torch.argmax(predictions, dim=-1, keepdim=False)
        mask = (labels > 0).to(self.device)
        mlm_accuracy = torch.sum((predictions == labels) * mask).float()
        mlm_accuracy /= (torch.sum(mask).float() + 1e-8)
        return mlm_accuracy.item()

compute_loss 方法計算 NSP 任務的 loss 值,get_mlm_accuracy 計算 MLM 任務的 loss 值。

6. Padding

    def padding(self, output_dic_lis):
        bert_input = [i["bert_input"] for i in output_dic_lis]
        bert_label = [i["bert_label"] for i in output_dic_lis]
        segment_label = [i["segment_label"] for i in output_dic_lis]
        bert_input = torch.nn.utils.rnn.pad_sequence(bert_input, batch_first=True)
        bert_label = torch.nn.utils.rnn.pad_sequence(bert_label, batch_first=True)
        segment_label = torch.nn.utils.rnn.pad_sequence(segment_label, batch_first=True)
        is_next = torch.cat([i["is_next"] for i in output_dic_lis])
        return {"bert_input": bert_input,
                "bert_label": bert_label,
                "segment_label": segment_label,
                "is_next": is_next}

對輸入的數據進行 Padding 補齊,在訓練數據初始化時,通過處理數據模塊已近將輸入序列的最大長度截斷至設置好的max_seq_len 長度。而這裏進行的是 Padding 補齊的操作。

7. iteration

    def iteration(self, epoch, data_loader, train=True, df_path="./output_wiki_bert/df_log.pickle"):
        if not os.path.isfile(df_path) and epoch != 0:
            raise RuntimeError("log DataFrame path not found and can't create a new one because we're not training from scratch!")
        if not os.path.isfile(df_path) and epoch == 0:
            df = pd.DataFrame(columns=["epoch", "train_next_sen_loss", "train_mlm_loss",
                                       "train_next_sen_acc", "train_mlm_acc",
                                       "test_next_sen_loss", "test_mlm_loss",
                                       "test_next_sen_acc", "test_mlm_acc"
                                       ])
            df.to_pickle(df_path)
            print("log DataFrame created!")

        str_code = "train" if train else "test"

        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        total_next_sen_loss = 0
        total_mlm_loss = 0
        total_next_sen_acc = 0
        total_mlm_acc = 0
        total_element = 0

        for i, data in data_iter:
            # print('IDX of data_iter:', i)
            data = self.padding(data)
            # 0. batch_data will be sent into the device(GPU or cpu)
            data = {key: value.to(self.device) for key, value in data.items()}
            positional_enc = self.positional_enc[:, :data["bert_input"].size()[-1], :].to(self.device)

            # 1. forward the next_sentence_prediction and masked_lm model
            mlm_preds, next_sen_preds = self.bert_model.forward(input_ids=data["bert_input"],
                                                                positional_enc=positional_enc,
                                                                token_type_ids=data["segment_label"])

            mlm_acc = self.get_mlm_accuracy(mlm_preds, data["bert_label"])
            next_sen_acc = next_sen_preds.argmax(dim=-1, keepdim=False).eq(data["is_next"]).sum().item()
            mlm_loss = self.compute_loss(mlm_preds, data["bert_label"], self.vocab_size, ignore_index=0)
            next_sen_loss = self.compute_loss(next_sen_preds, data["is_next"])
            loss = mlm_loss + next_sen_loss

            # 3. backward and optimization only in train
            if train:
                self.optimizer.zero_grad()
                loss.backward()
                # for param in self.model.parameters():
                #     print(param.grad.data.sum())
                self.optimizer.step()

            total_next_sen_loss += next_sen_loss.item()
            total_mlm_loss += mlm_loss.item()
            total_next_sen_acc += next_sen_acc
            total_element += data["is_next"].nelement()
            total_mlm_acc += mlm_acc

            if train:
                log_dic = {
                    "epoch": epoch,
                   "train_next_sen_loss": total_next_sen_loss / (i + 1),
                   "train_mlm_loss": total_mlm_loss / (i + 1),
                   "train_next_sen_acc": total_next_sen_acc / total_element,
                   "train_mlm_acc": total_mlm_acc / (i + 1),
                   "test_next_sen_loss": 0, "test_mlm_loss": 0,
                   "test_next_sen_acc": 0, "test_mlm_acc": 0
                }

            else:
                log_dic = {
                    "epoch": epoch,
                   "test_next_sen_loss": total_next_sen_loss / (i + 1),
                   "test_mlm_loss": total_mlm_loss / (i + 1),
                   "test_next_sen_acc": total_next_sen_acc / total_element,
                   "test_mlm_acc": total_mlm_acc / (i + 1),
                   "train_next_sen_loss": 0, "train_mlm_loss": 0,
                   "train_next_sen_acc": 0, "train_mlm_acc": 0
                }

            if i % 10 == 0:
                data_iter.write(str({k: v for k, v in log_dic.items() if v != 0 and k != "epoch"}))

        if train:
            df = pd.read_pickle(df_path)
            df = df.append([log_dic])
            df.reset_index(inplace=True, drop=True)
            df.to_pickle(df_path)
        else:
            log_dic = {k: v for k, v in log_dic.items() if v != 0 and k != "epoch"}
            df = pd.read_pickle(df_path)
            df.reset_index(inplace=True, drop=True)
            for k, v in log_dic.items():
                df.at[epoch, k] = v
            df.to_pickle(df_path)
            return float(log_dic["test_next_sen_loss"])+float(log_dic["test_mlm_loss"])

這部分代碼主要爲模型預訓的每個 epoch 的迭代過程,以及訓練中 log 的記錄。

源碼中,首先對要儲存的 log 進行定義,然後將 total loss 和 accuracy 置零,其次就是計算每個 batch 的 mlm loss 與 nsp loss ,在進行反向傳播,更新參數的過程。

最後記錄每個 epoch 的 訓練信息。

8. 模型保存

    def save_state_dict(self, model, epoch, dir_path="./output", file_path="bert.model"):
        if not os.path.exists(dir_path):
            os.mkdir(dir_path)
        save_path = dir_path+ "/" + file_path + ".epoch.{}".format(str(epoch))
        model.to("cpu")
        torch.save({"model_state_dict": model.state_dict()}, save_path)
        print("{} saved!".format(save_path))
        model.to(self.device)

這部分源碼也比較簡單,主要是模型的保存。

9. 訓練代碼

    def init_trainer(dynamic_lr, load_model=False):
        trainer = Pretrainer(BertForPreTraining,
                             vocab_size=config["vocab_size"],
                             max_seq_len=config["max_seq_len"],
                             batch_size=config["batch_size"],
                             lr=dynamic_lr,
                             with_cuda=True)
        if load_model:
            trainer.load_model(trainer.bert_model, dir_path=config["output_path"])
        return trainer


    start_epoch = 3
    train_epoches = 1
    trainer = init_trainer(config["lr"], load_model=True)
    # if train from scratch
    all_loss = []
    threshold = 0
    patient = 10
    best_f1 = 0
    dynamic_lr = config["lr"]
    for epoch in range(start_epoch, start_epoch + train_epoches):
        print("train with learning rate {}".format(str(dynamic_lr)))
        trainer.train(epoch)

        trainer.save_state_dict(trainer.bert_model, epoch, dir_path=config["output_path"],
                                file_path="bert.model")
        trainer.test(epoch)

這就是預訓練的代碼了,init_trainer 實例化前面定義的 Pretraniner 類,然後就是每個 epoch 調用類內的 train 方法來進行訓練了,每個 epoch 保存一個模型,並進行測試。


總結

以上就是 Pytorch 版本 Bert 模型的預訓練部分的全部源碼了,相較於 TensorFlow 版本還是略顯冗長,另外這裏的 Positional encoding 部分是使用的傳統 Transformer 中,直接計算得出的方式,與 Google 官方給出的源碼不同。官方源碼中將 Positional encoding 也作爲參數,加入了訓練,這會導致一定量的參數增加,但在如此大規模的訓練數據面前,一點點參數的增加也就不值一提了,而究竟哪種 encoding 方式帶來的效果更好,我自己還沒有進行調研,暫未得出結論。

 

如有問題歡迎指正,轉載請註明出處。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章