ERNIE is a pretrained model developed by Baidu with a model structure similar to BERT. Yes, ERNIE is what got me into Paddle.
ERNIE 1.0 improved BERT's masking strategy as shown in the figure; BERT later adopted a similar approach in models such as the whole-word-masking (wwm) variants.
ERNIE 2.0 adds more pre-training tasks on top of BERT; it has not been open-sourced yet.
My personal take on these tasks:
they are roughly trained pipeline-style, each task to its best on its own, and then fine-tuned jointly with multi-task training.
ERNIE's network is also a Transformer and should be the same as BERT's; the difference is that more tasks are added during pre-training.
Paddle wraps many pretrained models very nicely and they are convenient to call, but the heavy encapsulation can be confusing at first. This article tries to run an ERNIE classification model from scratch; later posts will cover calling the wrapped models and deploying them.
As usual there are three steps:
1. Data processing (turn the data into a format the model can consume)
2. Model construction (build the model you want to use)
3. Training and evaluating the model
1. Data processing
First we again need to define a reader to read and generate the data.
What sets this apart from other classification models is that the data must be processed into the format ERNIE expects.
The context part includes:
1. token_ids: the text converted to indices. Note that ERNIE ships its own vocabulary, so you do not need to build one yourself; call convert_tokens_to_ids from ERNIE's tokenization.py to generate the ids.
tokenization.py provides both tokenization and convert_tokens_to_ids:
from tokenization import FullTokenizer  # tokenization.py ships with the ERNIE repo

vocab_file = r'/ernie/vocab.txt'  # ERNIE's own vocabulary file
full_tokenize = FullTokenizer(vocab_file)
tokens = full_tokenize.tokenize("我出生於1960年,湖南人")
print(tokens)
print(full_tokenize.convert_tokens_to_ids(tokens))
2. text_type_ids: the segment ids of the input. A single sentence is all 0s; a sentence pair looks like [0, 0, …, 0, 1, 1, …, 1].
3. position_ids: the absolute position of each token, generated as follows (the code says it best):
position_ids = list(range(len(token_ids)))
That is everything the model needs. ERNIE 2.0 additionally requires a task_id (more on that later).
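Putting the three fields together, here is a minimal sketch for a sentence pair. The toy vocabulary and sample sentences below are made up for illustration; in practice the ids come from ERNIE's vocab.txt via convert_tokens_to_ids:

```python
# Hypothetical toy vocabulary; real ids come from ERNIE's vocab.txt.
toy_vocab = {"[CLS]": 1, "[SEP]": 2, "去": 3, "逛街": 4, "咯": 5, "好": 6}

tokens_a = ["去", "逛街", "咯"]  # sentence A
tokens_b = ["好"]                # sentence B

# Tokens: [CLS] A... [SEP] B... [SEP]
tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
# Segment ids: 0 for [CLS] + A + first [SEP], 1 for B + second [SEP]
text_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
token_ids = [toy_vocab[t] for t in tokens]
position_ids = list(range(len(token_ids)))  # absolute positions

print(token_ids)      # [1, 3, 4, 5, 2, 6, 2]
print(text_type_ids)  # [0, 0, 0, 0, 0, 1, 1]
print(position_ids)   # [0, 1, 2, 3, 4, 5, 6]
```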
ERNIE's reader does quite a lot of work (here we walk through the classification reader in detail):
There is first a BaseReader that covers the basic operations shared by classification, sequence labeling, and the other tasks, mainly:
1. Reading the file
def csv_reader(fd, delimiter='\t'):
    def gen():
        for i in fd:
            slots = i.rstrip('\n').split(delimiter)
            if len(slots) == 1:
                yield slots,
            else:
                yield slots
    return gen()
def _read_tsv(self, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with io.open(input_file, "r", encoding="utf8") as f:
        reader = csv_reader(f, delimiter="\t")
        headers = next(reader)  # e.g. [label, text_a]
        Example = namedtuple('Example', headers)  # map each header name to a field
        examples = []
        for line in reader:
            example = Example(*line)
            examples.append(example)
        return examples
The resulting examples look like:
[Example(label='1', text_a='去 逛街 咯')]
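The two functions above can be exercised without a file on disk; here is a self-contained sketch that feeds csv_reader an io.StringIO stand-in (the two-line TSV is hypothetical sample data):

```python
import io
from collections import namedtuple

def csv_reader(fd, delimiter='\t'):
    # Generator that splits each line of the file-like object on the delimiter.
    def gen():
        for i in fd:
            slots = i.rstrip('\n').split(delimiter)
            if len(slots) == 1:
                yield slots,
            else:
                yield slots
    return gen()

# Hypothetical two-line TSV standing in for a real file on disk.
tsv = io.StringIO("label\ttext_a\n1\t去 逛街 咯\n")
reader = csv_reader(tsv)
headers = next(reader)                    # ['label', 'text_a']
Example = namedtuple('Example', headers)  # field names taken from the header row
examples = [Example(*line) for line in reader]
print(examples)  # [Example(label='1', text_a='去 逛街 咯')]
```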
2. Converting an example into a record
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
    """Converts a single `Example` into a single `Record`."""
    text_a = tokenization.convert_to_unicode(example.text_a)
    tokens_a = tokenizer.tokenize(text_a)
    tokens_b = None
    if "text_b" in example._fields:
        text_b = tokenization.convert_to_unicode(example.text_b)
        tokens_b = tokenizer.tokenize(text_b)

    if tokens_b:
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length is less than the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    else:
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]

    # The convention in BERT/ERNIE is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    text_type_ids = []
    tokens.append("[CLS]")
    text_type_ids.append(0)
    for token in tokens_a:
        tokens.append(token)
        text_type_ids.append(0)
    tokens.append("[SEP]")
    text_type_ids.append(0)

    if tokens_b:
        for token in tokens_b:
            tokens.append(token)
            text_type_ids.append(1)
        tokens.append("[SEP]")
        text_type_ids.append(1)

    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    position_ids = list(range(len(token_ids)))
    if self.label_map:
        label_id = self.label_map[example.label]
    else:
        label_id = example.label

    Record = namedtuple(
        'Record',
        ['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'])

    qid = None
    if "qid" in example._fields:
        qid = example.qid

    record = Record(
        token_ids=token_ids,
        text_type_ids=text_type_ids,
        position_ids=position_ids,
        label_id=label_id,
        qid=qid)
    return record
3. Padding the data and generating batches
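Within a batch, every record is padded to the length of the longest sequence, and a mask marks the real tokens. The sketch below shows the idea only; the helper name pad_batch and the pad id 0 are assumptions for illustration, not ERNIE's exact code:

```python
def pad_batch(batch_token_ids, pad_id=0):
    """Pad a batch of variable-length id lists to the batch max length,
    returning the padded ids plus a 0/1 mask over the real tokens."""
    max_len = max(len(ids) for ids in batch_token_ids)
    padded, mask = [], []
    for ids in batch_token_ids:
        pad_len = max_len - len(ids)
        padded.append(ids + [pad_id] * pad_len)          # right-pad with pad_id
        mask.append([1] * len(ids) + [0] * pad_len)      # 1 = real token, 0 = pad
    return padded, mask

# Two hypothetical token_ids sequences of different lengths.
batch = [[1, 3, 4, 5, 2], [1, 6, 2]]
padded, mask = pad_batch(batch)
print(padded)  # [[1, 3, 4, 5, 2], [1, 6, 2, 0, 0]]
print(mask)    # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```

The same padding is applied to text_type_ids and position_ids so that all fields in a batch share one shape.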