背景介紹
文本分類是NLP中的常見的重要任務之一,它的主要功能就是將輸入的文本以及文本的類別訓練出一個模型,使之具有一定的泛化能力,能夠對新文本進行較好地預測。它的應用很廣泛,在很多領域發揮着重要作用,例如垃圾郵件過濾、輿情分析以及新聞分類等。
現階段的文本分類模型頻出,種類繁多,花樣百變,既有機器學習中的樸素貝葉斯模型、SVM等,也有深度學習中的各種模型,比如經典的CNN, RNN,以及它們的變形,如CNN-LSTM,還有各種高大上的Attention模型。
無疑,文本分類是一個相對比較成熟的任務,我們儘可以選擇自己喜歡的模型來完成該任務。本文以kashgari-tf爲例,它能夠支持各種文本分類模型,比如BiLSTM,CNN_LSTM,AVCNN等,且對預訓練模型,比如BERT的支持較好,它能讓我們輕鬆地完成文本分類任務。
下面,讓我們一起走進文本分類的世界,分分鐘搞定text classification!
項目
首先,我們需要找一份數據作爲例子。我們選擇THUCNews,THUCNews是根據新浪新聞RSS訂閱頻道2005~2011年間的歷史數據篩選過濾生成,包含74萬篇新聞文檔(2.19 GB),均爲UTF-8純文本格式。我們在原始新浪新聞分類體系的基礎上,從中選擇10個候選分類類別:體育、娛樂、家居、房產、教育、時尚、時政、遊戲、科技、財經。
數據總量一共爲6.5萬條,其中訓練集數據5萬條,每個類別5000條,驗證集數據0.5萬條,每個類別500條,測試集數據1萬條,每個類別1000條。筆者已將數據放在Github上,讀者可以在最後的總結中找到。
項目結構,如下圖:
接着,我們嘗試着利用kashgari-tf來訓練一個文本分類模型,其中模型我們採用CNN-LSTM,完整的Python代碼(text_classification_model_train.py)如下:
# -*- coding: utf-8 -*-
# time: 2019-08-13 11:16
# place: Pudong Shanghai
from kashgari.tasks.classification import CNN_LSTM_Model
# 獲取數據集
def load_data(data_type):
with open('./data/cnews.%s.txt' % data_type, 'r', encoding='utf-8') as f:
content = [_.strip() for _ in f.readlines() if _.strip()]
x, y = [], []
for line in content:
label, text = line.split(maxsplit=1)
y.append(label)
x.append([_ for _ in text])
return x, y
# 獲取數據
train_x, train_y = load_data('train')
valid_x, valid_y = load_data('val')
test_x, test_y = load_data('test')
# 訓練模型
model = CNN_LSTM_Model()
model.fit(train_x, train_y, valid_x, valid_y, batch_size=16, epochs=5)
# 評估模型
model.evaluate(test_x, test_y)
# 保存模型
model.save('text_classification_model')
輸出的模型結果如下:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input (InputLayer) (None, 2544) 0
_________________________________________________________________
layer_embedding (Embedding) (None, 2544, 100) 553200
_________________________________________________________________
conv1d (Conv1D) (None, 2544, 32) 9632
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 1272, 32) 0
_________________________________________________________________
cu_dnnlstm (CuDNNLSTM) (None, 100) 53600
_________________________________________________________________
dense (Dense) (None, 10) 1010
=================================================================
Total params: 617,442
Trainable params: 617,442
Non-trainable params: 0
設定模型訓練次數爲5個epoch,batch_size爲16。模型訓練完後,在訓練集、驗證集上的結果如下:
數據集 | accuracy | loss |
---|---|---|
訓練集 | 0.9661 | 0.1184 |
驗證集 | 0.9204 | 0.2567 |
在測試集上的結果如下:
precision recall f1-score support
體育 0.9852 0.9970 0.9911 1000
娛樂 0.9938 0.9690 0.9813 1000
家居 0.9384 0.8830 0.9098 1000
房產 0.9490 0.9680 0.9584 1000
教育 0.9650 0.8820 0.9216 1000
時尚 0.9418 0.9710 0.9562 1000
時政 0.9732 0.9450 0.9589 1000
遊戲 0.9454 0.9700 0.9576 1000
科技 0.8910 0.9560 0.9223 1000
財經 0.9566 0.9920 0.9740 1000
accuracy 0.9533 10000
macro avg 0.9539 0.9533 0.9531 10000
weighted avg 0.9539 0.9533 0.9531 10000
總的來說,上述模型訓練的效果還是很不錯的。接下來,是考驗模型的預測能力的時刻了,看看它是否具體文本分類的泛化能力。
測試
我們已經有了訓練好的模型text_classification_model
,接着讓我們利用該模型來對新的數據進行預測,預測的代碼(model_predict.py)如下:
# -*- coding: utf-8 -*-
# time: 2019-08-14 00:21
# place: Pudong Shanghai
import kashgari
# 加載模型
loaded_model = kashgari.utils.load_model('text_classification_model')
text = '華夏幸福成立於 1998 年,前身爲廊坊市華夏房地產開發有限公司,初始註冊資本 200 萬元,其中王文學出資 160 萬元,廊坊市融通物資貿易有限公司出資 40 萬元,後經多次股權轉讓和增資,公司於 2007 年整體改製爲股份制公司,2011 年完成借殼上市。'
x = [[_ for _ in text]]
label = loaded_model.predict(x)
print('預測分類:%s' % label)
以下是測試結果:
原文1: 華夏幸福成立於 1998 年,前身爲廊坊市華夏房地產開發有限公司,初始註冊資本 200 萬元,其中王文學出資 160 萬元,廊坊市融通物資貿易有限公司出資 40 萬元,後經多次股權轉讓和增資,公司於 2007 年整體改製爲股份制公司,2011 年完成借殼上市。
分類結果:預測分類:['財經']
原文2: 現今常見的短袖襯衫大致上可以分爲:夏威夷襯衫、古巴襯衫、保齡球衫,三者之間雖有些微分別,但其實有些時候,一件襯衫也可能包含了多種款式的特色。而‘古巴(領)襯衫’最顯而易見的特點在於‘領口’,通常會設計爲V領,且呈現微微的外翻,也因此缺少襯衫領口常見的‘第一顆鈕釦’,衣服到領子的剪裁爲一體成形,整體較寬鬆舒適。
分類結果:預測分類:['時尚']
原文3:周琦2014年加盟新疆廣匯籃球俱樂部,當年就代表俱樂部青年隊接連拿下全國籃球青年聯賽冠軍和全國俱樂部青年聯賽冠軍。升入一隊後,周琦2016年隨隊出戰第25屆亞冠杯,獲得冠軍。2016-2017賽季,周琦爲新疆廣匯隊奪得隊史首座總冠軍獎盃立下汗馬功勞,他在總決賽中帶傷出戰,更是傳爲佳話。
分類結果:預測分類:['體育']
原文4: 周杰倫[微博]監製賽車電影《叱吒風雲》13日釋出花絮導演篇,不僅真實賽車競速畫面大量曝光,幾十輛百萬賽車在國際專業賽道、山路飆速,場面浩大震撼,更揭開不少
現場拍攝的幕後畫面。監製周杰倫在現場與導演討論劇本、范逸臣[微博]與高英軒大打出手、甚至有眼尖網友發現在花絮中閃過“男神”李玉璽[微博]的畫面。
分類結果:預測分類:['娛樂']
原文5: 北京時間8月13日上午消息,據《韓國先驅報》網站報道,近日美國知識產權所有者協會( Intellectual Property Owners Association)發佈的一份報告顯示,在獲得的
美國專利數量方面,IBM、微軟和通用電氣等美國企業名列前茅,排在後面的韓國科技巨頭三星、LG與之競爭激烈。
分類結果:預測分類:['科技']
總結
雖然我們上述測試的文本分類效果還不錯,但也存在着一些分類錯誤的情況。
本文講述瞭如何利用kashgari-tf模塊來快速地搭建文本分類任務,其實,也沒那麼難!
本文代碼和數據及已上傳至Github, 網址爲:
https://github.com/percent4/cnews_text_classification
注意:不妨瞭解下筆者的微信公衆號: Python爬蟲與算法(微信號爲:easy_web_scrape), 歡迎大家關注~