本文檔介紹瞭如何使用BERT實現多類別文本分類任務,適合稍微瞭解BERT和文本分類的同學參考。
(一) 下載
首先,在github上clone谷歌的BERT項目,或者直接下載。項目地址
然後,下載中文預訓練模型,地址
(二) 環境準備
tensorflow >= 1.11.0
注意:
- 在GPU上運行Tensorflow,需要CUDA版本和Tensorflow版本的對應。比如Tensorflow-1.11.0最高只能使用9.0版本的CUDA,否則加載時會出現找不到libcublas.so的錯誤。
- 安裝TensorFlow時,如果出現無法卸載enum34的錯誤,可以用pip install *** --ignore_installed enum34命令先跳過。
(三) 數據準備
準備數據集,包括訓練集、驗證集、測試集,格式相同,每行爲一個類別+文本,用“\t”間隔。(如果選擇其他間隔符,需要修改run_classifier.py中_read_tsv方法)。
我做的是新聞文本分類,數據格式如下:
(四) 修改run_classifier.py文件
- 添加處理數據集的類,class ZbsProcessor(DataProcessor),分別實現以下方法:
def get_train_examples(self, data_dir): 讀取訓練集
def get_dev_examples(self, data_dir): 讀取驗證集
def get_test_examples(self, data_dir): 讀取測試集
def get_labels(self, labels): 獲得類別集合
def _create_examples(self, lines, set_type): 生成訓練和驗證樣本
- 修改main函數。在第744行,將ZbsProcessor添加到processors中
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"zbs": ZbsProcessor
}
- 原代碼中,先判斷是否train,然後獲取訓練樣本,但是後面需要所有類別,所以需要改成先獲取所有類別,然後判斷判斷是否train。即代碼:
if FLAGS.do_train:
train_examples = processor.get_train_examples(FLAGS.data_dir)
改爲:
train_examples,train_labels,temp=processor.get_train_examples(FLAGS.data_dir)
label_list = processor.get_labels(train_labels)
if FLAGS.do_train:
- 修改運行參數。可以直接在代碼裏修改,也可以執行.py文件時提供參數。參數意義:
data_dir:存放數據集的文件夾
bert_config_file:bert中文模型中的bert_config.json文件
task_name:processors中添加的任務名“zbs”
vocab_file:bert中文模型中的vocab.txt文件
output_dir:訓練好的分類器模型的存放文件夾
init_checkpoint:bert中文模型中的bert_model.ckpt.index文件
do_train:是否訓練,設置爲“True”
do_eval:是否驗證,設置爲“True”
do_predict:是否測試,設置爲“False”
可調參數:
max_seq_length:輸入文本序列的最大長度,也就是每個樣本的最大處理長度,多餘會去掉,不夠會補齊。最大值512。
train_batch_size: 訓練模型求梯度時,批量處理數據集的大小。值越大,訓練速度越快,內存佔用越多。
eval_batch_size: 驗證時,批量處理數據集的大小。同上。
predict_batch_size: 測試時,批量處理數據集的大小。同上。
learning_rate: 反向傳播更新權重時,步長大小。值越大,訓練速度越快。值越小,訓練速度越慢,收斂速度慢,
容易過擬合。遷移學習中,一般設置較小的步長(小於2e-4)
num_train_epochs:所有樣本完全訓練一遍的次數。
warmup_proportion:用於warmup的訓練集的比例。
save_checkpoints_steps:檢查點的保存頻率。
(五) 運行。
如果在文件中已經設置後參數,直接運行即可。
也可以在執行.py文件時,傳入參數,例如:
python zbs_classifier.py --data_dir=/home/hls/bert_zbs_data/data2c-11
--init_checkpoint=/home/hls/bert_zbs_data/data2c-11/out1/model.ckpt-3616.index
--output_dir=/home/hls/bert_zbs_data/data2c-11/out3
--max_seq_length=256
--learning_rate=2e-5
--num_train_epochs=50
(六) 附錄代碼
class ZbsProcessor(DataProcessor):
def get_train_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self, labels):
return set(labels)
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
labels = []
labels_test = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
if set_type == "test":
label = "臺灣" #這裏要設置成數據集中一個真實的類別
else:
label = tokenization.convert_to_unicode(line[0])
labels.append(label)
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples, labels, labels_test
(七)代碼結構
class InputExample(object):
class InputFeatures(object):
class DataProcessor(object):
class ZbsProcessor(DataProcessor):
def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer):
將單條訓練數據,由 class InputExample 結構轉換成 class InputFeature 的結構
def file_based_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_file):
遍歷訓練樣本,將其轉換成InputFeatures特徵,並保存到train.TFRecord文件中。調用convert_single_example()方法實現單條數據轉換。
def file_based_input_fn_builder(input_file, seq_length, is_training,
drop_remainder):
根據保存的訓練文件train.TFRecord,生成tf.data.TFRecordDataset用於提供給Estimator來訓練。
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu, use_one_hot_embeddings):
返回tf.contrib.tpu.TPUEstimatorSpec對象。
(八) 有趣的ISSUES
- 如何在訓練時輸出loss
logging_hook = tf.train.LoggingTensorHook({“loss”: total_loss}, every_n_iter=10)
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode,
loss=total_loss,
train_op=train_op,
training_hooks=[logging_hook],
scaffold_fn=scaffold_fn)