The source code is on my GitHub:
https://github.com/danan0755/Bert_Classifier
Toutiao news dataset download:
Link: https://pan.baidu.com/s/1P9G8pl4B78aPL-wtFTKf0g
Extraction code: 4hm9
The pretrained model is chinese_roberta_wwm_ext_L-12_H-768_A-12.
Pretrained model download:
Link: https://pan.baidu.com/s/11iqTeuja63leVkTDO5dy8g
Extraction code: es0h
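# BERT model setup
# Imports assume keras-bert with standalone Keras; config_path, checkpoint_path,
# train_data, valid_data and data_generator are defined elsewhere in the source.
from keras.layers import Dense, Input, Lambda
from keras.models import Model
from keras.optimizers import Adam
from keras_bert import load_trained_model_from_checkpoint

bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)  # load the pretrained model
for layer in bert_model.layers:
    layer.trainable = True  # fine-tune every BERT layer

x1_in = Input(shape=(None,))  # token ids
x2_in = Input(shape=(None,))  # segment ids
x = bert_model([x1_in, x2_in])
x = Lambda(lambda x: x[:, 0])(x)  # take the [CLS] vector for classification
p = Dense(15, activation='softmax')(x)  # 15 output classes

model = Model([x1_in, x2_in], p)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=Adam(1e-5),
              metrics=['accuracy'])
model.summary()

train_D = data_generator(train_data)
valid_D = data_generator(valid_data)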
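The snippet calls `data_generator` without showing it. Below is a minimal sketch of what such a generator might look like, assuming each sample is a `(text, label)` pair and that an `encode` callable returns token ids and segment ids (in keras-bert this role is typically played by `Tokenizer.encode`). The names `pad_batch`, `BatchGenerator`, `toy_encode` and the batch size are hypothetical and are not taken from the repository.

```python
import numpy as np


def pad_batch(sequences, pad_value=0):
    """Right-pad variable-length id lists to the longest one in the batch."""
    max_len = max(len(s) for s in sequences)
    return np.array([list(s) + [pad_value] * (max_len - len(s)) for s in sequences])


class BatchGenerator:
    """Hypothetical stand-in for data_generator: yields ([token_ids, segment_ids], labels)."""

    def __init__(self, data, encode, batch_size=32, shuffle=True):
        self.data = data              # list of (text, label) pairs (assumed layout)
        self.encode = encode          # callable: text -> (token_ids, segment_ids)
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        # Number of batches per epoch, rounding up for a final partial batch.
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # Loop forever, as Keras generator-based training expects.
        while True:
            order = np.arange(len(self.data))
            if self.shuffle:
                np.random.shuffle(order)
            for start in range(0, len(order), self.batch_size):
                batch = [self.data[i] for i in order[start:start + self.batch_size]]
                encoded = [self.encode(text) for text, _ in batch]
                token_ids = pad_batch([ids for ids, _ in encoded])
                segment_ids = pad_batch([seg for _, seg in encoded])
                labels = np.array([label for _, label in batch])
                yield [token_ids, segment_ids], labels


def toy_encode(text):
    """Toy encoder for demonstration only: one fake id per character,
    wrapped in placeholder start/end ids, with all-zero segment ids."""
    ids = [101] + [ord(ch) % 1000 + 1 for ch in text] + [102]
    return ids, [0] * len(ids)
```

Padding with 0 assumes that id 0 is the padding token, as in the standard BERT vocabulary. In a real run, `toy_encode` would be replaced by the tokenizer's own encoding function, and `len(gen)` would give the number of steps per epoch to pass to the training call.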
After 5 epochs:
loss: 0.1223 - accuracy: 0.9654 - val_loss: 0.9093 - val_accuracy: 0.8500
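For reference, the logged `loss` and `accuracy` values are averages of the sparse categorical cross-entropy and the top-1 accuracy. The NumPy sketch below shows how both can be computed from softmax outputs and integer class labels. The probabilities and labels are made-up toy values, not results from this model, and the helper names are chosen for illustration only.

```python
import numpy as np


def sparse_categorical_crossentropy(probs, labels):
    """Mean negative log-probability assigned to the true class."""
    true_class_probs = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(true_class_probs)))


def accuracy(probs, labels):
    """Fraction of samples whose highest-probability class matches the label."""
    return float(np.mean(np.argmax(probs, axis=1) == labels))


# Toy softmax outputs for 3 samples over 3 classes (illustrative values only).
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.4, 0.3],
])
labels = np.array([0, 1, 2])  # integer class ids, as the sparse loss expects

loss = sparse_categorical_crossentropy(probs, labels)
acc = accuracy(probs, labels)
print(f"loss: {loss:.4f} - accuracy: {acc:.4f}")
```

A low training loss together with a much higher validation loss, as in the log above, is the pattern usually associated with overfitting.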