Comparing RoBERTa with BERT

What exactly does RoBERTa improve?

Three training improvements:

  • Drop the next-sentence prediction (NSP) objective.
  • Dynamic masking. BERT relies on randomly masking and predicting tokens. The original BERT implementation performs masking once during data preprocessing, producing a single static mask. RoBERTa instead uses dynamic masking: a new masking pattern is generated every time a sequence is fed to the model. As large amounts of data keep flowing in, the model gradually adapts to different masking patterns and learns different language representations.
  • Text encoding. Byte-Pair Encoding (BPE) is a hybrid of character-level and word-level representations that can handle the many common words in natural-language corpora. The original BERT implementation uses a character-level BPE vocabulary of size 30K, learned after preprocessing the input with heuristic tokenization rules. The Facebook researchers instead train BERT with a larger byte-level BPE vocabulary of 50K subword units, without any additional preprocessing or tokenization of the input (see the tokenizer sketch after this list).
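
As a rough illustration of the vocabulary difference, the sketch below (a minimal sketch, assuming the Hugging Face transformers library and the public bert-base-uncased / roberta-base checkpoints, which are stand-ins rather than the exact models from the paper) prints the two vocabulary sizes and an example tokenization:

# Sketch: BERT's ~30K WordPiece vocabulary vs. RoBERTa's ~50K byte-level BPE
# vocabulary (assumes `pip install transformers`).
from transformers import BertTokenizer, RobertaTokenizer

bert_tok = BertTokenizer.from_pretrained("bert-base-uncased")
roberta_tok = RobertaTokenizer.from_pretrained("roberta-base")

print(bert_tok.vocab_size)       # 30522
print(roberta_tok.vocab_size)    # 50265

# WordPiece marks word-internal pieces with "##"; byte-level BPE marks
# word-initial pieces with "Ġ" and falls back to raw bytes, so there is no
# [UNK] for unseen characters.
print(bert_tok.tokenize("unhappiness"))
print(roberta_tok.tokenize(" unhappiness"))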

Why is dynamic masking better than static masking?

The original BERT randomly selects 15% of the tokens in each sequence to replace with [MASK]. To reduce the mismatch with downstream tasks, each selected token is (1) replaced with [MASK] 80% of the time, (2) kept unchanged 10% of the time, and (3) replaced with a random word 10% of the time. However, once those 15% of tokens are chosen they never change for the rest of training: they are picked once at the start and stay the same for all N epochs. This is static masking.

RoBERTa instead copies the pretraining data 10 times, and each copy independently selects 15% of its tokens to mask, so the same sentence gets 10 different mask patterns. Each copy is then trained for N/10 epochs, which means that over the N epochs of training the masked tokens of each sequence keep changing. This is dynamic masking.
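
The difference can be sketched in a few lines of toy code (illustrative only; select_mask_positions is a made-up helper, not part of either codebase):

# Toy sketch of static vs. dynamic masking.
import random

def select_mask_positions(tokens, rng, mask_prob=0.15):
    """Randomly pick ~15% of the positions to mask (hypothetical helper)."""
    num = max(1, round(len(tokens) * mask_prob))
    return sorted(rng.sample(range(len(tokens)), num))

tokens = "the quick brown fox jumps over the lazy dog".split()
rng = random.Random(0)

# Static masking (original BERT): the positions are chosen once during
# preprocessing and the same mask is reused in every epoch.
static_positions = select_mask_positions(tokens, rng)
for epoch in range(3):
    print("static :", static_positions)

# Dynamic masking (RoBERTa): the mask is re-sampled every time the sequence
# is fed to the model, so each epoch sees a different pattern. Duplicating
# the data 10 times with 10 different masks approximates this.
for epoch in range(3):
    print("dynamic:", select_mask_positions(tokens, rng))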

Why use this text encoding?

Byte-level BPE can encode any input as a sequence of bytes, so there are no unknown tokens and no language-specific preprocessing or tokenization is required; a single 50K vocabulary covers arbitrary text.

Replacing static masking with dynamic masking

RoBERTa's masking code:

import collections
import re

# `MaskedLmInstance` is defined at module level in create_pretraining_data.py;
# `FLAGS` comes from tf.flags in the same script (do_whole_word_mask is a
# boolean command-line flag).
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
    """Creates the predictions for the masked LM objective."""

    cand_indexes = []
    for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue
        # Whole Word Masking means that we mask all of the wordpieces
        # corresponding to an original word. When a word has been split into
        # WordPieces, the first token does not have any marker and any
        # subsequent tokens are prefixed with ##. So whenever we see a ##
        # token, we append its index to the previous set of word indexes.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
                token.startswith("##")):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])

    rng.shuffle(cand_indexes)

    # Strip the "##" prefix from Chinese wordpieces so the output tokens read
    # as plain characters (specific to this Chinese whole-word-masking code).
    output_tokens = [t[2:] if len(re.findall('##[\u4E00-\u9FA5]', t)) > 0 else t for t in tokens]

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    masked_lms = []
    covered_indexes = set()
    for index_set in cand_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            # 80% of the time, replace with [MASK]
            if rng.random() < 0.8:
                masked_token = "[MASK]"
            else:
                # 10% of the time, keep original
                if rng.random() < 0.5:
                    masked_token = tokens[index][2:] if len(re.findall('##[\u4E00-\u9FA5]', tokens[index]))>0 else tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

            output_tokens[index] = masked_token

            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    assert len(masked_lms) <= num_to_predict
    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    masked_lm_positions = []
    masked_lm_labels = []
    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)

    # tf.logging.info('%s' % (tokens))
    # tf.logging.info('%s' % (output_tokens))
    return (output_tokens, masked_lm_positions, masked_lm_labels)
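
A hypothetical usage sketch for the function above (the _Flags stand-in and the toy inputs are made up for illustration; in the real script FLAGS comes from tf.flags, and the masking probability here is exaggerated so the tiny example actually masks something):

# Hypothetical usage of the whole-word-masking version above.
import random

class _Flags:                       # stand-in for the real tf FLAGS object
    do_whole_word_mask = True
FLAGS = _Flags()

rng = random.Random(12345)
tokens = ["[CLS]", "使", "##用", "動", "##態", "掩", "##碼", "[SEP]"]
vocab_words = ["使", "用", "動", "態", "掩", "碼", "的", "了"]

out_tokens, positions, labels = create_masked_lm_predictions(
    tokens, masked_lm_prob=0.3,      # exaggerated so the toy example masks a whole word
    max_predictions_per_seq=4, vocab_words=vocab_words, rng=rng)

# Whole words (e.g. 動/##態) are masked together, and "##" is stripped from
# the Chinese wordpieces kept in the output.
print(out_tokens, positions, labels)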

The original BERT's mask creation code:

import collections


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
  """Creates the predictions for the masked LM objective."""

  cand_indexes = []
  for (i, token) in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
      continue
    cand_indexes.append(i)

  rng.shuffle(cand_indexes)

  output_tokens = list(tokens)

  masked_lm = collections.namedtuple("masked_lm", ["index", "label"])  # pylint: disable=invalid-name

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))

  masked_lms = []
  covered_indexes = set()
  for index in cand_indexes:
    if len(masked_lms) >= num_to_predict:
      break
    if index in covered_indexes:
      continue
    covered_indexes.add(index)

    masked_token = None
    # 80% of the time, replace with [MASK]
    if rng.random() < 0.8:
      masked_token = "[MASK]"
    else:
      # 10% of the time, keep original
      if rng.random() < 0.5:
        masked_token = tokens[index]
      # 10% of the time, replace with random word
      else:
        masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

    output_tokens[index] = masked_token

    masked_lms.append(masked_lm(index=index, label=tokens[index]))

  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  masked_lm_positions = []
  masked_lm_labels = []
  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)

  return (output_tokens, masked_lm_positions, masked_lm_labels)
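
For contrast, the same toy inputs can be run through the original BERT routine above (again a hypothetical sketch); here every WordPiece is a candidate on its own, so a lone "##態" can be masked while "動" stays intact, and the "##" prefixes are kept in the output:

# Hypothetical usage of the original BERT version above (no whole-word grouping).
import random

rng = random.Random(12345)
tokens = ["[CLS]", "使", "##用", "動", "##態", "掩", "##碼", "[SEP]"]
vocab_words = ["使", "用", "動", "態", "掩", "碼", "的", "了"]

out_tokens, positions, labels = create_masked_lm_predictions(
    tokens, masked_lm_prob=0.3, max_predictions_per_seq=4,
    vocab_words=vocab_words, rng=rng)
print(out_tokens, positions, labels)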

Note that the byte-level text encoding mainly helps for English; for Chinese, word-level WordPiece is still needed.

Link for learning about BPE

 
