Training BERT Yourself
This post trains BERT with the scripts Google provides, then looks at BERT's computation graph in TensorBoard.
BERT repository:
https://github.com/google-research/bert
Clone the repo:
git clone https://github.com/google-research/bert
Download the pretrained BERT-Base (uncased) model. It includes vocab.txt, which is needed later:
wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
Unzip the pretrained model into a directory:
:~/ugetdownload$ unzip uncased_L-12_H-768_A-12.zip
Archive: uncased_L-12_H-768_A-12.zip
creating: uncased_L-12_H-768_A-12/
inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.meta
inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001
inflating: uncased_L-12_H-768_A-12/vocab.txt
inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.index
inflating: uncased_L-12_H-768_A-12/bert_config.json
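The BERT scripts must run under TensorFlow 1.x; under 2.x they fail (the error is shown in the errors section below). Create a separate environment for them, and quote the version specifier, otherwise the shell treats < as input redirection:
conda create -n py37tf1 python=3.7
conda activate py37tf1
pip install "tensorflow<2.0"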
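Before running anything, it can help to confirm the new environment really has TensorFlow 1.x. A minimal sketch follows; the helpers is_tf1 and check_tensorflow are made up for illustration, and the only fact they rely on is that TF 2.x removed tf.flags and tf.app, which these scripts use.

```python
import sys


def is_tf1(version: str) -> bool:
    """Return True if a version string such as '1.15.0' is a 1.x release."""
    return version.split(".", 1)[0] == "1"


def check_tensorflow() -> bool:
    """Import TensorFlow (if installed) and report whether it is 1.x."""
    try:
        import tensorflow as tf  # imported lazily so is_tf1 works without TF
    except ImportError:
        print("TensorFlow is not installed in this environment", file=sys.stderr)
        return False
    if not is_tf1(tf.__version__):
        print("Found TensorFlow %s; the BERT scripts need 1.x" % tf.__version__,
              file=sys.stderr)
        return False
    print("TensorFlow %s OK" % tf.__version__)
    return True
```

Call check_tensorflow() inside the py37tf1 environment. On Python 3.7, "tensorflow<2.0" would be expected to resolve to a 1.15.x release, which is consistent with the TensorBoard 1.15.0 seen later.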
Set the BERT_BASE_DIR environment variable:
export BERT_BASE_DIR=~/ugetdownload/uncased_L-12_H-768_A-12
Run the data script from the root of the cloned bert repo to generate pretraining data:
(py37tf1) ~/code/github_read/google-research/bert$ python create_pretraining_data.py \
--input_file=./sample_text.txt \
--output_file=/tmp/tf_examples.tfrecord \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
WARNING:tensorflow:From create_pretraining_data.py:469: The name tf.app.run is deprecated. Please use tf.compat.v1.app.run instead.
WARNING:tensorflow:From create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.
W0502 17:39:58.054978 139793997326144 module_wrapper.py:139] From create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.
WARNING:tensorflow:From create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.
W0502 17:39:58.055087 139793997326144 module_wrapper.py:139] From create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.
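The flags above control how masked-LM targets are picked. The sketch below is a simplified, stdlib-only re-implementation of my reading of create_masked_lm_predictions in create_pretraining_data.py: the number of targets is round(len(tokens) * masked_lm_prob), at least 1 and capped at max_predictions_per_seq, and each chosen token becomes [MASK] 80% of the time, stays unchanged 10% of the time, and becomes a random vocabulary word 10% of the time. It assumes plain WordPiece tokens, skips whole-word masking, and uses one random draw instead of the script's two; num_to_predict and mask_tokens are hypothetical names. The --dupe_factor=5 flag repeats this over five copies of the input, each with different random masks.

```python
import random
from typing import List, Tuple


def num_to_predict(num_tokens: int, masked_lm_prob: float = 0.15,
                   max_predictions_per_seq: int = 20) -> int:
    """Number of masked-LM targets for a sequence of num_tokens tokens."""
    return min(max_predictions_per_seq,
               max(1, int(round(num_tokens * masked_lm_prob))))


def mask_tokens(tokens: List[str], vocab_words: List[str], rng: random.Random,
                masked_lm_prob: float = 0.15,
                max_predictions_per_seq: int = 20
                ) -> Tuple[List[str], List[int], List[str]]:
    """Simplified BERT masking; returns (masked tokens, positions, labels)."""
    # [CLS] and [SEP] are never chosen as prediction targets.
    candidates = [i for i, t in enumerate(tokens) if t not in ("[CLS]", "[SEP]")]
    rng.shuffle(candidates)
    n = num_to_predict(len(tokens), masked_lm_prob, max_predictions_per_seq)
    positions = sorted(candidates[:n])
    output = list(tokens)
    for i in positions:
        r = rng.random()
        if r < 0.8:    # 80%: replace with [MASK]
            output[i] = "[MASK]"
        elif r < 0.9:  # 10%: keep the original token
            pass
        else:          # 10%: replace with a random vocabulary word
            output[i] = rng.choice(vocab_words)
    labels = [tokens[i] for i in positions]  # what the model must predict
    return output, positions, labels
```

With --max_seq_length=128, a full sequence gets round(128 * 0.15) = 19 targets, just under the --max_predictions_per_seq=20 cap. As far as I can tell the repo's README recommends setting max_predictions_per_seq to roughly max_seq_length * masked_lm_prob and passing the same value to both scripts, which is what the two commands here do.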
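The warnings above are only TensorFlow deprecation notices and can be ignored. Next, run the training script in the same directory: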
python run_pretraining.py \
--input_file=/tmp/tf_examples.tfrecord \
--output_dir=/tmp/pretraining_output \
--do_train=True \
--do_eval=True \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--train_batch_size=32 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=20 \
--num_warmup_steps=10 \
--learning_rate=2e-5
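To see what --learning_rate, --num_warmup_steps and --num_train_steps do together, here is a plain-Python sketch of the schedule as I read create_optimizer in the repo's optimization.py: a linear decay from the initial rate to 0 over num_train_steps (polynomial_decay with power 1.0), replaced by a linear warmup while the step is below num_warmup_steps. bert_learning_rate is a hypothetical name; check the real optimization.py before relying on the details.

```python
def bert_learning_rate(step: int, init_lr: float = 2e-5,
                       num_train_steps: int = 20,
                       num_warmup_steps: int = 10) -> float:
    """Learning rate at a given global step (linear warmup + linear decay)."""
    # Linear decay to 0, computed from step 0 and clamped at num_train_steps.
    decay_step = min(step, num_train_steps)
    lr = init_lr * (1.0 - decay_step / num_train_steps)
    # During warmup the rate is step / num_warmup_steps * init_lr instead.
    if num_warmup_steps and step < num_warmup_steps:
        lr = init_lr * step / num_warmup_steps
    return lr


if __name__ == "__main__":
    for s in (0, 5, 9, 10, 15, 20):
        print(s, bert_learning_rate(s))
```

With the flags used here, the rate ramps up to 1.8e-5 at step 9 and then drops to 1e-5 at step 10, because the decay is measured from step 0 rather than from the end of warmup; it reaches 0 at step 20.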
I0502 17:48:02.704535 140404206638912 run_pretraining.py:173] name = cls/seq_relationship/output_bias:0, shape = (2,), *INIT_FROM_CKPT*
WARNING:tensorflow:From run_pretraining.py:198: The name tf.metrics.accuracy is deprecated. Please use tf.compat.v1.metrics.accuracy instead.
W0502 17:48:02.709678 140404206638912 module_wrapper.py:139] From run_pretraining.py:198: The name tf.metrics.accuracy is deprecated. Please use tf.compat.v1.metrics.accuracy instead.
WARNING:tensorflow:From run_pretraining.py:202: The name tf.metrics.mean is deprecated. Please use tf.compat.v1.metrics.mean instead.
W0502 17:48:02.722177 140404206638912 module_wrapper.py:139] From run_pretraining.py:202: The name tf.metrics.mean is deprecated. Please use tf.compat.v1.metrics.mean instead.
INFO:tensorflow:Done calling model_fn.
I0502 17:48:02.767565 140404206638912 estimator.py:1150] Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-05-02T17:48:02Z
I0502 17:48:02.787446 140404206638912 evaluation.py:255] Starting evaluation at 2020-05-02T17:48:02Z
INFO:tensorflow:Graph was finalized.
I0502 17:48:03.247700 140404206638912 monitored_session.py:240] Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/pretraining_output/model.ckpt-20
INFO:tensorflow:Evaluation [100/100]
I0502 17:50:20.824831 140404206638912 evaluation.py:167] Evaluation [100/100]
INFO:tensorflow:Finished evaluation at 2020-05-02-17:50:20
I0502 17:50:20.975484 140404206638912 evaluation.py:275] Finished evaluation at 2020-05-02-17:50:20
INFO:tensorflow:Saving dict for global step 20: global_step = 20, loss = 0.27436933, masked_lm_accuracy = 0.95210946, masked_lm_loss = 0.273851, next_sentence_accuracy = 1.0, next_sentence_loss = 0.0004196863
I0502 17:50:20.975750 140404206638912 estimator.py:2049] Saving dict for global step 20: global_step = 20, loss = 0.27436933, masked_lm_accuracy = 0.95210946, masked_lm_loss = 0.273851, next_sentence_accuracy = 1.0, next_sentence_loss = 0.0004196863
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmp/pretraining_output/model.ckpt-20
I0502 17:50:21.689311 140404206638912 estimator.py:2109] Saving 'checkpoint_path' summary for global step 20: /tmp/pretraining_output/model.ckpt-20
INFO:tensorflow:evaluation_loop marked as finished
I0502 17:50:21.689822 140404206638912 error_handling.py:101] evaluation_loop marked as finished
INFO:tensorflow:***** Eval results *****
I0502 17:50:21.689958 140404206638912 run_pretraining.py:483] ***** Eval results *****
INFO:tensorflow: global_step = 20
I0502 17:50:21.690061 140404206638912 run_pretraining.py:485] global_step = 20
INFO:tensorflow: loss = 0.27436933
I0502 17:50:21.690387 140404206638912 run_pretraining.py:485] loss = 0.27436933
INFO:tensorflow: masked_lm_accuracy = 0.95210946
I0502 17:50:21.690463 140404206638912 run_pretraining.py:485] masked_lm_accuracy = 0.95210946
INFO:tensorflow: masked_lm_loss = 0.273851
I0502 17:50:21.690540 140404206638912 run_pretraining.py:485] masked_lm_loss = 0.273851
INFO:tensorflow: next_sentence_accuracy = 1.0
I0502 17:50:21.690627 140404206638912 run_pretraining.py:485] next_sentence_accuracy = 1.0
INFO:tensorflow: next_sentence_loss = 0.0004196863
I0502 17:50:21.690728 140404206638912 run_pretraining.py:485] next_sentence_loss = 0.0004196863
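The same metrics are also written to eval_results.txt in the output directory as sorted "key = value" lines (the 156-byte file in the directory listing below matches exactly these six lines). A small sketch for reading them back, e.g. to compare runs; parse_eval_results and load_eval_results are hypothetical helpers. Keep in mind that run_pretraining.py evaluates on the same --input_file it trained on, starting from the pretrained checkpoint, so accuracies like 0.95 and 1.0 here only show that the pipeline works, not how well the model generalizes.

```python
from typing import Dict


def parse_eval_results(text: str) -> Dict[str, float]:
    """Parse the 'key = value' lines that run_pretraining.py writes."""
    results: Dict[str, float] = {}
    for line in text.splitlines():
        if " = " not in line:
            continue  # skip blank or unexpected lines
        key, value = line.split(" = ", 1)
        results[key.strip()] = float(value)
    return results


def load_eval_results(path: str) -> Dict[str, float]:
    """Read and parse an eval_results.txt file."""
    with open(path, encoding="utf-8") as f:
        return parse_eval_results(f.read())
```

For example, load_eval_results("/tmp/pretraining_output/eval_results.txt")["masked_lm_accuracy"].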
Visualizing with TensorBoard
Now that basic training runs end to end, let's use TensorBoard to visualize BERT's training.
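Because BERT's scripts use TensorFlow's Estimator API (run_pretraining.py uses TPUEstimator), the events files TensorBoard needs are written by default to the directory given by output_dir (/tmp/pretraining_output).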
(base) :/tmp/pretraining_output$ ll
total 2610064
drwxr-xr-x 3 wenkai wenkai 4096 May 2 17:50 ./
drwxrwxrwt 65 root root 12288 May 2 18:07 ../
-rw-rw-r-- 1 wenkai wenkai 126 May 2 17:48 checkpoint
drwxr-xr-x 2 wenkai wenkai 4096 May 2 17:50 eval/
-rw-rw-r-- 1 wenkai wenkai 156 May 2 17:50 eval_results.txt
-rw-rw-r-- 1 wenkai wenkai 13311481 May 2 17:48 events.out.tfevents.1588412530.G6
-rw-rw-r-- 1 wenkai wenkai 9153045 May 2 17:42 graph.pbtxt
-rw-rw-r-- 1 wenkai wenkai 1321277144 May 2 17:42 model.ckpt-0.data-00000-of-00001
-rw-rw-r-- 1 wenkai wenkai 23350 May 2 17:42 model.ckpt-0.index
-rw-rw-r-- 1 wenkai wenkai 3796855 May 2 17:42 model.ckpt-0.meta
-rw-rw-r-- 1 wenkai wenkai 1321277144 May 2 17:48 model.ckpt-20.data-00000-of-00001
-rw-rw-r-- 1 wenkai wenkai 23350 May 2 17:48 model.ckpt-20.index
-rw-rw-r-- 1 wenkai wenkai 3796855 May 2 17:48 model.ckpt-20.meta
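Each model.ckpt-*.data file in the listing above is about 1.32 GB. A back-of-the-envelope check, assuming float32 values and that the checkpoint stores every weight together with the two Adam moment slots (adam_m and adam_v) that BERT's optimizer keeps for it, i.e. three values per parameter. estimate_param_count is a hypothetical helper and ignores small extras such as global_step.

```python
def estimate_param_count(ckpt_data_bytes: int, bytes_per_value: int = 4,
                         copies_per_param: int = 3) -> float:
    """Rough parameter count from a checkpoint .data file size.

    Assumes float32 (4 bytes) and weight + adam_m + adam_v per parameter.
    """
    return ckpt_data_bytes / (bytes_per_value * copies_per_param)


size = 1321277144  # model.ckpt-20.data-00000-of-00001 from the listing
print("~%.1fM parameters" % (estimate_param_count(size) / 1e6))  # prints "~110.1M parameters"
```

The result lands close to the ~110M parameters usually quoted for BERT-Base, which supports the three-copies assumption.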
Start TensorBoard on that directory to visualize the training process, then open port 6007 in a browser:
(py37tf1) wenkai@wenkai-HP-EliteBook-840-G6:/tmp/pretraining_output$ tensorboard --logdir . --port 6007
W0502 18:08:16.384034 139671431165696 plugin_event_accumulator.py:294] Found more than one graph event per run, or there was a metagraph containing a graph_def, as well as one or more graph events. Overwriting the graph with the newest event.
W0502 18:08:16.389085 139671431165696 plugin_event_accumulator.py:302] Found more than one metagraph event per run. Overwriting the metagraph with the newest event.
TensorBoard 1.15.0 at http://0.0.00:6007/ (Press CTRL+C to quit)
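With --logdir ., TensorBoard treats each directory under the log directory that contains events files as a separate run, so here the training events in /tmp/pretraining_output itself and the evaluation events in eval/ should show up as two runs. As the two warnings themselves say, when a run contains more than one graph event TensorBoard just keeps the newest one. The sketch below lists runs in roughly the same way using only pathlib; find_runs is a hypothetical helper that only approximates TensorBoard's discovery logic.

```python
from pathlib import Path
from typing import Dict, List


def find_runs(logdir: str) -> Dict[str, List[str]]:
    """Map each run (directory relative to logdir) to its events files.

    Any directory under logdir holding an events.out.tfevents.* file
    counts as a run; "." stands for logdir itself.
    """
    root = Path(logdir)
    runs: Dict[str, List[str]] = {}
    for path in sorted(root.rglob("events.out.tfevents.*")):
        run = path.parent.relative_to(root).as_posix()
        runs.setdefault(run, []).append(path.name)
    return runs
```

For example, find_runs("/tmp/pretraining_output") should report a "." run and an "eval" run.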
Common errors
Running with TF 2.x fails:
Traceback (most recent call last):
File "create_pretraining_data.py", line 26, in <module>
flags = tf.flags
AttributeError: module 'tensorflow' has no attribute 'flags'
Fix: use TensorFlow 1.x. The scripts rely on tf.flags, which TF 2.x removed.
BERT_BASE_DIR is not set:
Traceback (most recent call last):
File "create_pretraining_data.py", line 469, in <module>
tf.app.run()
File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/tensorflow_core/python/platform/app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "create_pretraining_data.py", line 440, in main
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
File "/home/wenkai/code/github_read/google-research/bert/tokenization.py", line 165, in __init__
self.vocab = load_vocab(vocab_file)
File "/home/wenkai/code/github_read/google-research/bert/tokenization.py", line 127, in load_vocab
token = convert_to_unicode(reader.readline())
File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/tensorflow_core/python/lib/io/file_io.py", line 178, in readline
self._preread_check()
File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/tensorflow_core/python/lib/io/file_io.py", line 84, in _preread_check
compat.as_bytes(self.__name), 1024 * 512)
tensorflow.python.framework.errors_impl.NotFoundError: /vocab.txt; No such file or directory
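Fix: set BERT_BASE_DIR correctly. If the variable is unset (for example in a new terminal where the export was never run), $BERT_BASE_DIR/vocab.txt expands to /vocab.txt, which is exactly the missing path in the error above.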
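To catch this before launching a long job, a small preflight check can verify that BERT_BASE_DIR is set and contains the files the two commands read. A minimal sketch: missing_bert_files and REQUIRED_FILES are hypothetical names, and the file list is taken from the unzip listing and the flags used in this post.

```python
import os
from typing import List, Mapping, Optional

# Files read from $BERT_BASE_DIR by the commands in this post.
REQUIRED_FILES = (
    "vocab.txt",                            # --vocab_file
    "bert_config.json",                     # --bert_config_file
    "bert_model.ckpt.index",                # --init_checkpoint (index)
    "bert_model.ckpt.data-00000-of-00001",  # --init_checkpoint (weights)
)


def missing_bert_files(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """List problems with BERT_BASE_DIR; an empty list means it looks usable."""
    env = os.environ if env is None else env
    base = env.get("BERT_BASE_DIR", "")
    if not base:
        # Unset or empty: the shell turns $BERT_BASE_DIR/vocab.txt into /vocab.txt.
        return ["BERT_BASE_DIR is not set"]
    base = os.path.expanduser(base)
    return [os.path.join(base, name) for name in REQUIRED_FILES
            if not os.path.isfile(os.path.join(base, name))]
```

Running print(missing_bert_files() or "BERT_BASE_DIR looks good") in the same shell before the scripts would flag the problem above immediately.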
Notes on using TensorBoard (feel free to skip)
Since we are on TensorFlow 1.x, TensorBoard must be 1.x as well. The relevant docs are here:
https://github.com/tensorflow/tensorboard/blob/master/docs/r1/summaries.md
https://github.com/tensorflow/tensorboard/blob/master/docs/r1/graphs.md
https://github.com/tensorflow/tensorboard/blob/master/docs/r1/overview.md
The key point:
The FileWriter takes a logdir in its constructor - this logdir is quite important, it's the directory where all of the events will be written out. Also, the FileWriter can optionally take a Graph in its constructor. If it receives a Graph object, then TensorBoard will visualize your graph along with tensor shape information.
https://github.com/tensorflow/tensorboard/blob/master/docs/r1/summaries.md