Three training improvements:
- Drop the next-sentence prediction (NSP) objective.
- Dynamic masking. BERT relies on randomly masking tokens and predicting them. The original BERT implementation performed masking once during data preprocessing, producing a single static mask. RoBERTa instead uses dynamic masking: a new mask pattern is generated every time a sequence is fed to the model. As large amounts of data stream through training, the model is exposed to many different masking patterns and learns richer language representations.
- Text encoding. Byte-Pair Encoding (BPE) is a hybrid between character-level and word-level representations that handles the large vocabularies of natural-language corpora. The original BERT implementation uses a character-level BPE vocabulary of 30K units, learned after preprocessing the input with heuristic tokenization rules. The Facebook researchers instead trained BERT with a larger byte-level BPE vocabulary of 50K subword units, with no additional preprocessing or tokenization of the input.
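To make the character-level vs. byte-level distinction concrete, here is a minimal sketch (illustrative only, not RoBERTa's actual tokenizer): the base alphabet of a byte-level BPE is fixed at 256 symbols, so any UTF-8 text can be represented losslessly without an unknown-token fallback, whereas a character-level vocabulary must enumerate every distinct character the corpus can contain.

```python
# Illustrative contrast between character-level and byte-level symbol inventories.
text = "naïve 深度学习 café"

char_symbols = set(text)                  # grows without bound across a corpus
byte_symbols = set(text.encode("utf-8"))  # always a subset of {0, ..., 255}

# Round trip: byte-level encoding is lossless, so no [UNK] fallback is needed.
restored = text.encode("utf-8").decode("utf-8")
```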
Why is dynamic masking better than static masking?
The original BERT randomly selects 15% of the tokens in each sequence and replaces them for prediction. To reduce the mismatch with downstream tasks, each selected token is (1) replaced with [MASK] 80% of the time; (2) left unchanged 10% of the time; (3) replaced with a random token 10% of the time. But once those 15% of tokens are selected, they never change for the rest of training: the same positions chosen at the start stay fixed across all N epochs. This is static masking.
RoBERTa instead duplicates the pre-training data 10 times, and each copy independently selects 15% of its tokens for masking, so the same sentence is masked in 10 different ways. Each copy is then trained for N/10 epochs, so over the full N epochs the masked positions of each sequence vary. This is dynamic masking. (The full RoBERTa implementation goes further, generating the mask pattern on the fly each time a sequence is fed to the model.)
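The contrast above can be sketched in a few lines. The sentence, seed, and helper name below are invented for illustration; the point is that the static pattern is sampled once and reused, while duplication yields a different pattern per copy.

```python
import random

def sample_mask(tokens, rng, mask_prob=0.15):
    # Choose ~15% of the positions to mask, as in BERT/RoBERTa pre-training.
    n = max(1, round(len(tokens) * mask_prob))
    return tuple(sorted(rng.sample(range(len(tokens)), n)))

sentence = ["tok%d" % i for i in range(20)]  # a 20-token dummy sequence
rng = random.Random(42)

# Static masking: the pattern is sampled once and reused every epoch.
static_pattern = sample_mask(sentence, rng)
static_epochs = [static_pattern for _ in range(10)]

# "Dynamic" masking via duplication: 10 copies, each masked independently.
dynamic_epochs = [sample_mask(sentence, rng) for _ in range(10)]
```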
Why change the text encoding?
From static masking to dynamic masking
RoBERTa's masking code (Chinese whole-word-masking variant):
import collections
import re

# MaskedLmInstance is a namedtuple defined in BERT's create_pretraining_data.py;
# FLAGS are that script's command-line flags (do_whole_word_mask enables
# whole-word masking).
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
  """Creates the predictions for the masked LM objective."""
  cand_indexes = []
  for (i, token) in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
      continue
    # Whole Word Masking means that we mask all of the wordpieces
    # corresponding to an original word. When a word has been split into
    # WordPieces, the first token does not have any marker and any subsequent
    # tokens are prefixed with ##. So whenever we see a ## token, we
    # append it to the previous set of word indexes.
    #
    # Note that Whole Word Masking does *not* change the training code
    # at all -- we still predict each WordPiece independently, softmaxed
    # over the entire vocabulary.
    if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
        token.startswith("##")):
      cand_indexes[-1].append(i)
    else:
      cand_indexes.append([i])

  rng.shuffle(cand_indexes)

  # Strip the "##" prefix from Chinese wordpieces so the output reads as
  # plain characters.
  output_tokens = [t[2:] if len(re.findall('##[\u4E00-\u9FA5]', t)) > 0 else t
                   for t in tokens]

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))

  masked_lms = []
  covered_indexes = set()
  for index_set in cand_indexes:
    if len(masked_lms) >= num_to_predict:
      break
    # If adding a whole-word mask would exceed the maximum number of
    # predictions, then just skip this candidate.
    if len(masked_lms) + len(index_set) > num_to_predict:
      continue
    is_any_index_covered = False
    for index in index_set:
      if index in covered_indexes:
        is_any_index_covered = True
        break
    if is_any_index_covered:
      continue
    for index in index_set:
      covered_indexes.add(index)

      masked_token = None
      # 80% of the time, replace with [MASK]
      if rng.random() < 0.8:
        masked_token = "[MASK]"
      else:
        # 10% of the time, keep original
        if rng.random() < 0.5:
          masked_token = (tokens[index][2:]
                          if len(re.findall('##[\u4E00-\u9FA5]',
                                            tokens[index])) > 0
                          else tokens[index])
        # 10% of the time, replace with random word
        else:
          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

      output_tokens[index] = masked_token
      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

  assert len(masked_lms) <= num_to_predict
  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  masked_lm_positions = []
  masked_lm_labels = []
  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)

  return (output_tokens, masked_lm_positions, masked_lm_labels)
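The whole-word grouping logic at the top of that function can be isolated into a small self-contained sketch (`group_wordpieces` is a hypothetical name, not part of the BERT codebase): "##"-prefixed continuation pieces are attached to the preceding piece so that a whole word is masked or skipped as one unit.

```python
def group_wordpieces(tokens):
    # Group each "##"-prefixed continuation piece with the piece before it,
    # so whole-word masking treats the word as a single candidate.
    groups = []
    for i, tok in enumerate(tokens):
        if tok in ("[CLS]", "[SEP]"):
            continue
        if tok.startswith("##") and groups:
            groups[-1].append(i)
        else:
            groups.append([i])
    return groups

tokens = ["[CLS]", "play", "##ing", "chess", "[SEP]"]
groups = group_wordpieces(tokens)  # "play ##ing" becomes one group
```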
BERT's mask creation code:
import collections


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
  """Creates the predictions for the masked LM objective."""
  cand_indexes = []
  for (i, token) in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
      continue
    cand_indexes.append(i)

  rng.shuffle(cand_indexes)

  output_tokens = list(tokens)

  masked_lm = collections.namedtuple("masked_lm", ["index", "label"])  # pylint: disable=invalid-name

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))

  masked_lms = []
  covered_indexes = set()
  for index in cand_indexes:
    if len(masked_lms) >= num_to_predict:
      break
    if index in covered_indexes:
      continue
    covered_indexes.add(index)

    masked_token = None
    # 80% of the time, replace with [MASK]
    if rng.random() < 0.8:
      masked_token = "[MASK]"
    else:
      # 10% of the time, keep original
      if rng.random() < 0.5:
        masked_token = tokens[index]
      # 10% of the time, replace with random word
      else:
        masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

    output_tokens[index] = masked_token
    masked_lms.append(masked_lm(index=index, label=tokens[index]))

  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  masked_lm_positions = []
  masked_lm_labels = []
  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)

  return (output_tokens, masked_lm_positions, masked_lm_labels)
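The 80/10/10 replacement rule shared by both versions can be checked empirically with a self-contained sketch (the function and vocabulary below are invented for illustration): over many trials, roughly 80% of selected tokens become [MASK], 10% stay unchanged, and 10% become a random vocabulary word.

```python
import random

def choose_replacement(token, vocab_words, rng):
    # BERT's 80/10/10 rule for a token already chosen to be predicted:
    # 80% -> [MASK], 10% -> keep the original, 10% -> random vocab word.
    if rng.random() < 0.8:
        return "[MASK]"
    if rng.random() < 0.5:
        return token
    return vocab_words[rng.randint(0, len(vocab_words) - 1)]

rng = random.Random(0)
vocab = ["apple", "banana", "cherry"]
counts = {"mask": 0, "keep": 0, "random": 0}
for _ in range(10000):
    out = choose_replacement("orig", vocab, rng)
    if out == "[MASK]":
        counts["mask"] += 1
    elif out == "orig":
        counts["keep"] += 1
    else:
        counts["random"] += 1
```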
My takeaway: the byte-level text-encoding change is useful for English; Chinese still needs word-level WordPiece.