HAN論文模型代碼復現與重構

論文簡介

本文主要介紹CMU在2016年發表在ACL的一篇論文:Hierarchical Attention Networks for Document Classification及其代碼復現。

該論文是用於文檔級情感分類(document-level sentiment classification)的,其模型架構如下:
在這裏插入圖片描述
該模型稱爲層次注意力模型(Hierarchical Attention Network),根據作者所述,

  • 層次是指:句子由單詞組成,文檔由句子組成,據此可以構建一個自下而上的層次結構。

  • 注意力是指:組成某個句子的單詞對該句子的情感傾向的貢獻是不同的,通常來說,形容詞的貢獻(如good)就比名詞(如book)更大;同理,組成文檔的句子對該文檔的情感傾向的貢獻也不同,例如某些句子可能僅僅是陳述事實,而另一些句子則很明顯地表達出了自己的觀點。據此,作者提出使用注意力機制來挖掘句子和文檔中對情感分類比較重要的部分(btw,注意力機制比較成熟的最早的應用是Google發表的Attention is All you Need一文中)。

對詞嵌入進行編解碼的方式無非是雙向GRU或CNN等,此處不再贅述。需要注意的是,該模型中的注意力機制分爲兩個部分,分別是word attention和sentence attention,即分別在單詞和句子上應用注意力機制,可視化結果如下:
在這裏插入圖片描述
可以看出注意力機制的可視化結果高亮出了情感極性比較強的單詞。例如左邊帶delicious的評論預測結果爲4分(較好),帶terrible的評論的預測結果爲0分(極差)。
由此也可說明注意力機制是有效的。

代碼復現及重構

(顯然這篇幾年前的論文的代碼不是我寫的)
本文參考了github上對該模型的復現代碼:textClassifier,源代碼就不詳細解釋了,稍有複雜的也就是數據處理部分,源碼實現將訓練data設爲三維的,並在詞嵌入後餵給了HAN模型。

考慮到源代碼結構不是很清晰,也無法自定義輸入的詞嵌入的維度和訓練數據集,因此本文對該代碼進行了重構。

首先說明Python版本和依賴的庫:

Python >= 3.6
numpy
pandas
re
bs4
pickle
sklearn
gensim
nltk
keras
tensorflow

Python版本需要大於3.6,至於其他庫的話,只要版本不太落後一般都沒問題

下面詳細介紹改動的部分。

參數選項

原文沒有提供參數選項,如果要輸入不同維度的詞嵌入文件,則每次都要修改源代碼,十分不便,爲此, 我在重構時加入了參數選項,主要代碼如下:

parser = argparse.ArgumentParser('HAN')
parser.add_argument('--full_data_path', '-d', 
				help='Full path of  data', default=FULL_DATA_PATH)
parser.add_argument('--processed_pickle_data_path', '-D', 
				help='Full path of processed pickle data', default=PROCESSED_PICKLE_DATA_PATH)
parser.add_argument('--embedding_path', '-s', 
				help='The pre-trained embedding vector', default=EMBEDDING_PATH)
parser.add_argument('--model_path', '-m', help='Full path of  model', default=MODEL_PATH)
parser.add_argument('--epoch', '-e', help='Epochs', type=int, default=EPOCH)
parser.add_argument('--batch_size', '-b', help='Batch size', type=int, default=BATCH)
parser.add_argument('--training_data_ready', '-t', 
				help='Pass when training data is ready', action='store_true')
parser.add_argument('--model_ready', '-M', 
				help='Pass when model is ready', action='store_true')
parser.add_argument('--verbosity', '-v', 
				help='verbosity, stackable. 0: Error, 1: Warning, 2: Info, 3: Debug', action='count')
parser.description = 'Implementation of HAN for Sentiment Classification task'
parser.epilog = "Larry King@https://github.com/Larry955"

相應的變量定義在han_config.py文件中。

詳細參數說明如下:

  • –full_data_path, 要輸入的訓練文件的路徑,該文件必須爲tsv格式
  • –processed_pickle_data_path, 已經處理過的數據集的路徑
  • –embedding_path, 預訓練詞向量文件的路徑
  • –model_path, 保存的模型的路徑
  • –epoch, epoch個數
  • –batch_size, batch size
  • –training_data_ready, 數據集是否已被處理過,顯式輸入該參數時表明數據集已被處理過,否則會報錯
  • –model_ready, 模型是否已保存好,顯式輸入該參數時表明模型已被保存,否則會報錯
  • –verbosity, emmmm…

假設該文件爲HAN. py,那麼輸入

python HAN.py --help

可得:
在這裏插入圖片描述
輸入

python HAN.py --full_data_path=train_data.tsv --embedding_path=GoogleNews-vectors-negative300.bin --epoch=20

表示數據集的路徑爲train_data.tsv,預訓練詞嵌入文件爲GoogleNews,epoch爲20。
輸入

python HAN.py --training_data_ready --model_ready

表示訓練集和模型都已經準備好,可以直接加載。

詞嵌入文件解析

原代碼中只能解析glove詞嵌入,並且詞嵌入維度固定300維,我在重構時對詞嵌入文件進行了簡單的解析,使得模型可以接受不同的詞嵌入文件(目前支持glove和GoogleNews兩種),並能根據文件名提取出詞嵌入的維度。主要代碼如下:

emb_file_flag = ''
embedding_dim = 0

if embedding_path.find('glove') != -1:    
    emb_file_flag = 'glove'     # pre-trained word vector is glove    
    embedding_dim = int(((embedding_path.split('/')[-1]).split('.')[2])[:-1])
elif embedding_path.find('GoogleNews-vectors-negative300.bin') != 
-1:    
    emb_file_flag = 'google'    # pre-trained word vector is GoogleNews    
    embedding_dim = 300

得到詞嵌入文件和維度後,再根據emb_file_flag針對不同的文件獲取詞向量:

embeddings_index = {}
if emb_file_flag == 'glove':    
    f = open(os.path.join(embedding_path), encoding='utf-8')    
    for line in f:        
        values = line.split()        
        word = values[0]        
        vec = np.asarray(values[1:], dtype='float32')        
        embeddings_index[word] = vec    
    f.close()
elif emb_file_flag == 'google':    
    wv_from_bin = KeyedVectors.load_word2vec_format(emb_path, 
binary=True)    
    for word, vector in zip(wv_from_bin.vocab, wv_from_bin.vectors):        
        vec = np.asarray(vector, dtype='float32')        
        embeddings_index[word] = vec

示例:

python HAN.py  --embedding_path=GoogleNews-vectors-negative300.bin  # pre-trained word vector file is GooleNews with 300d
python HAN.py --embedding_path=glove.6B.100d.txt    # pre-trained file is glove with 100d
python HAN.py --embedding_path=glove.6B.200d.txt    # 200d

保存已訓練數據集及模型

原代碼中,每次運行時都要對數據集進行處理,並且要重新訓練模型,這對於百萬級文檔數據集而言十分耗時,爲此,我在重構時設置了相應的參數選項,從而能通過直接加載保存的文件已避免多次訓練,大大降低訓練時間。代碼如下:

  • 保存和加載已訓練數據集
if is_training_data_ready:    
    with open(pickle_path, 'rb') as f:        
        # print('data ready')        
        data, labels, word_index = pickle.load(f)    
    f.close()
else:    
    data, labels, word_index = process_data(data_path)    
    with open(pickle_path, 'wb') as f:        
        pickle.dump((data, labels, word_index), f)    # save trained dataset
    f.close()
    
# Generate data for training, validation and test
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.1, random_state=1)
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.1, random_state=1)
  • 保存和加載已訓練模型
if is_model_ready:    
    # print('model ready')    
    model = load_model(model_path, custom_objects={'AttLayer': 
AttLayer})
else:    
# Generate embedding matrix consists of embedding vector    
    embedding_matrix = create_emb_mat(embedding_path, word_index, 
embedding_dim)    # Create model for training    
    model = create_model(embedding_matrix)    
    model.save(model_path)  # save model

需要說明的是,由於該模型中自定義了不在keras.layers中的層(AttLayer),因此直接load_model時會報錯:github:keras/issues/#8612,爲解決該問題,可參考我的另一篇博客:
使用keras調用load_model時報錯ValueError: Unknown Layer:LayerName

添加函數和程序入口

原代碼中只有一個數據預處理函數clean_str和一個類AttLayer,其餘部分混雜其間,導致代碼結構混亂,不易理解,爲此,我在重構時將各項功能以函數形式封裝,並添加主程序入口和註釋,大大提升了代碼的可讀性。此處不再贅述。

實驗結果

這是在IMDB二分類數據集上進行的實驗,共25000條評論,train/val/test的劃分爲8/1/1,epoch爲10,優化函數爲rmsprop。
在這裏插入圖片描述

總結

這次重構基本把原代碼核心功能(模型相關代碼、注意力層AttLayer)以外的部分改得面目全非了,添加了上述功能後,跑模型時可以輸入自己想要的信息,避免在源代碼上進行修改,具有更高的彈性和可讀性,和原來相比好了很多。重構後的代碼見my github-HAN

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章