NLP(十六)輕鬆上手文本分類

背景介紹

  文本分類是NLP中的常見的重要任務之一,它的主要功能就是將輸入的文本以及文本的類別訓練出一個模型,使之具有一定的泛化能力,能夠對新文本進行較好地預測。它的應用很廣泛,在很多領域發揮着重要作用,例如垃圾郵件過濾、輿情分析以及新聞分類等。
  現階段的文本分類模型頻出,種類繁多,花樣百變,既有機器學習中的樸素貝葉斯模型、SVM等,也有深度學習中的各種模型,比如經典的CNN, RNN,以及它們的變形,如CNN-LSTM,還有各種高大上的Attention模型。
  無疑,文本分類是一個相對比較成熟的任務,我們儘可以選擇自己喜歡的模型來完成該任務。本文以kashgari-tf爲例,它能夠支持各種文本分類模型,比如BiLSTM,CNN_LSTM,AVCNN等,且對預訓練模型,比如BERT的支持較好,它能讓我們輕鬆地完成文本分類任務。
  下面,讓我們一起走進文本分類的世界,分分鐘搞定text classification!

項目

  首先,我們需要找一份數據作爲例子。我們選擇THUCNews,THUCNews是根據新浪新聞RSS訂閱頻道2005~2011年間的歷史數據篩選過濾生成,包含74萬篇新聞文檔(2.19 GB),均爲UTF-8純文本格式。我們在原始新浪新聞分類體系的基礎上,從中選擇10個候選分類類別:體育、娛樂、家居、房產、教育、時尚、時政、遊戲、科技、財經。
  數據總量一共爲6.5萬條,其中訓練集數據5萬條,每個類別5000條,驗證集數據0.5萬條,每個類別500條,測試集數據1萬條,每個類別1000條。筆者已將數據放在Github上,讀者可以在最後的總結中找到。
  項目結構,如下圖:

  接着,我們嘗試着利用kashgari-tf來訓練一個文本分類模型,其中模型我們採用CNN-LSTM,完整的Python代碼(text_classification_model_train.py)如下:

# -*- coding: utf-8 -*-
# time: 2019-08-13 11:16
# place: Pudong Shanghai

from kashgari.tasks.classification import CNN_LSTM_Model

# 獲取數據集
def load_data(data_type):
    with open('./data/cnews.%s.txt' % data_type, 'r', encoding='utf-8') as f:
        content = [_.strip() for _ in f.readlines() if _.strip()]

    x, y = [], []
    for line in content:
        label, text = line.split(maxsplit=1)
        y.append(label)
        x.append([_ for _ in text])

    return x, y

# 獲取數據
train_x, train_y = load_data('train')
valid_x, valid_y = load_data('val')
test_x, test_y = load_data('test')

# 訓練模型
model = CNN_LSTM_Model()
model.fit(train_x, train_y, valid_x, valid_y, batch_size=16, epochs=5)

# 評估模型
model.evaluate(test_x, test_y)

# 保存模型
model.save('text_classification_model')

輸出的模型結果如下:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input (InputLayer)           (None, 2544)              0
_________________________________________________________________
layer_embedding (Embedding)  (None, 2544, 100)         553200
_________________________________________________________________
conv1d (Conv1D)              (None, 2544, 32)          9632
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 1272, 32)          0
_________________________________________________________________
cu_dnnlstm (CuDNNLSTM)       (None, 100)               53600
_________________________________________________________________
dense (Dense)                (None, 10)                1010
=================================================================
Total params: 617,442
Trainable params: 617,442
Non-trainable params: 0

  設定模型訓練次數爲5個epoch,batch_size爲16。模型訓練完後,在訓練集、驗證集上的結果如下:

數據集 accuracy loss
訓練集 0.9661 0.1184
驗證集 0.9204 0.2567

在測試集上的結果如下:

             precision    recall  f1-score   support

          體育     0.9852    0.9970    0.9911      1000
          娛樂     0.9938    0.9690    0.9813      1000
          家居     0.9384    0.8830    0.9098      1000
          房產     0.9490    0.9680    0.9584      1000
          教育     0.9650    0.8820    0.9216      1000
          時尚     0.9418    0.9710    0.9562      1000
          時政     0.9732    0.9450    0.9589      1000
          遊戲     0.9454    0.9700    0.9576      1000
          科技     0.8910    0.9560    0.9223      1000
          財經     0.9566    0.9920    0.9740      1000

    accuracy                         0.9533     10000
   macro avg     0.9539    0.9533    0.9531     10000
weighted avg     0.9539    0.9533    0.9531     10000

  總的來說,上述模型訓練的效果還是很不錯的。接下來,是考驗模型的預測能力的時刻了,看看它是否具體文本分類的泛化能力。

測試

  我們已經有了訓練好的模型text_classification_model,接着讓我們利用該模型來對新的數據進行預測,預測的代碼(model_predict.py)如下:

# -*- coding: utf-8 -*-
# time: 2019-08-14 00:21
# place: Pudong Shanghai

import kashgari

# 加載模型
loaded_model = kashgari.utils.load_model('text_classification_model')

text = '華夏幸福成立於 1998 年,前身爲廊坊市華夏房地產開發有限公司,初始註冊資本 200 萬元,其中王文學出資 160 萬元,廊坊市融通物資貿易有限公司出資 40 萬元,後經多次股權轉讓和增資,公司於 2007 年整體改製爲股份制公司,2011 年完成借殼上市。'

x = [[_ for _ in text]]

label = loaded_model.predict(x)
print('預測分類:%s' % label)

以下是測試結果:

原文1: 華夏幸福成立於 1998 年,前身爲廊坊市華夏房地產開發有限公司,初始註冊資本 200 萬元,其中王文學出資 160 萬元,廊坊市融通物資貿易有限公司出資 40 萬元,後經多次股權轉讓和增資,公司於 2007 年整體改製爲股份制公司,2011 年完成借殼上市。
分類結果:預測分類:['財經']

原文2: 現今常見的短袖襯衫大致上可以分爲:夏威夷襯衫、古巴襯衫、保齡球衫,三者之間雖有些微分別,但其實有些時候,一件襯衫也可能包含了多種款式的特色。而‘古巴(領)襯衫’最顯而易見的特點在於‘領口’,通常會設計爲V領,且呈現微微的外翻,也因此缺少襯衫領口常見的‘第一顆鈕釦’,衣服到領子的剪裁爲一體成形,整體較寬鬆舒適。
分類結果:預測分類:['時尚']

原文3:周琦2014年加盟新疆廣匯籃球俱樂部,當年就代表俱樂部青年隊接連拿下全國籃球青年聯賽冠軍和全國俱樂部青年聯賽冠軍。升入一隊後,周琦2016年隨隊出戰第25屆亞冠杯,獲得冠軍。2016-2017賽季,周琦爲新疆廣匯隊奪得隊史首座總冠軍獎盃立下汗馬功勞,他在總決賽中帶傷出戰,更是傳爲佳話。
分類結果:預測分類:['體育']

原文4: 周杰倫[微博]監製賽車電影《叱吒風雲》13日釋出花絮導演篇,不僅真實賽車競速畫面大量曝光,幾十輛百萬賽車在國際專業賽道、山路飆速,場面浩大震撼,更揭開不少
現場拍攝的幕後畫面。監製周杰倫在現場與導演討論劇本、范逸臣[微博]與高英軒大打出手、甚至有眼尖網友發現在花絮中閃過“男神”李玉璽[微博]的畫面。
分類結果:預測分類:['娛樂']

原文5: 北京時間8月13日上午消息,據《韓國先驅報》網站報道,近日美國知識產權所有者協會( Intellectual Property Owners Association)發佈的一份報告顯示,在獲得的
美國專利數量方面,IBM、微軟和通用電氣等美國企業名列前茅,排在後面的韓國科技巨頭三星、LG與之競爭激烈。
分類結果:預測分類:['科技']

總結

  雖然我們上述測試的文本分類效果還不錯,但也存在着一些分類錯誤的情況。
  本文講述瞭如何利用kashgari-tf模塊來快速地搭建文本分類任務,其實,也沒那麼難!
  本文代碼和數據及已上傳至Github, 網址爲:
https://github.com/percent4/cnews_text_classification

注意:不妨瞭解下筆者的微信公衆號: Python爬蟲與算法(微信號爲:easy_web_scrape), 歡迎大家關注~

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章