解鎖深度表格學習(Deep Tabular Learning)的關鍵:算術特徵交互

近日,阿里雲人工智能平臺PAI與浙江大學吳健、應豪超老師團隊合作論文《Arithmetic Feature Interaction is Necessary for Deep Tabular Learning》正式在國際人工智能頂會AAAI-2024上發表。本項工作聚焦於深度表格學習中的一個核心問題:在處理結構化表格數據(tabular data)時,深度模型是否擁有有效的歸納偏差(inductive bias)。我們提出算術特徵交互(arithmetic feature interaction)對深度表格學習是至關重要的假設,並通過創建合成數據集以及設計實現一種支持上述交互的AMFormer架構(一種修改的Transformer架構)來驗證這一假設。實驗結果表明,AMFormer在合成數據集表現出顯著更優的細粒度表格數據建模、訓練樣本效率和泛化能力,並在真實數據的對比上超過一衆基準方法,成爲深度表格學習新的SOTA(state-of-the-art)模型。

背景

圖1:結構化表格數據示例,引用自[Borisov et al.]

結構化表格數據——這些數據往往以表(Table)的形式存儲於數據庫或數倉中——作爲一種在金融、市場營銷、醫學科學和推薦系統等多個領域廣泛使用的重要數據格式,其分析一直是機器學習研究的熱點。表格數據(圖1)通常同時包含數值型(numerical)特徵和類目型(categorical)特徵,並往往伴隨有特徵缺失、噪聲、類別不平衡(class imblanance)等數據質量問題,且缺少時序性、局部性等有效的先驗歸納偏差,極大地帶來了分析上的挑戰。傳統的樹集成模型(如,XGBoost、LightGBM、CatBoost)因在處理數據質量問題上的魯棒性,依然是工業界實際建模的主流選擇,但其效果很大程度依賴於特徵工程產出的原始特徵質量。

隨着深度學習的流行,研究者試圖引入深度學習端到端建模,從而減少在處理表格數據時對特徵工程的依賴。相關的研究工作至少可以可以分成四大類:(1)在傳統建模方法中疊加深度學習模塊(通常是多層感知機MLP),如Wide&Deep、DeepFMs;(2)形狀函數(shape function)採用深度學習建模的廣義加性模型(generalized additive model),如 NAM、NBM、SIAN;(3)樹結構啓發的深度模型,如NODE、Net-DNF;(4)基於Transformer架構的模型,如AutoInt、DCAP、FT-Transformer。儘管如此,深度學習在表格數據上相比樹模型的提升並不顯著且持續,其有效性仍然存在疑問,表格數據因此被視爲深度學習尚未征服的最後堡壘。

算術特徵交互在深度表格學習的“必要性”

我們認爲現有的深度表格學習方法效果不盡如人意的關鍵癥結在於沒有找到有效的建模歸納偏差,並進一步提出算術特徵交互對深度表格學習是至關重要的假設。本節介紹我們通過創建一個合成數據集,並對比引入算數特徵交互前後的模型效果,來驗證該假設。

合成數據集的構造方法如下:我們設計了一個包含八個特徵( X^{_{1}},...X^{_{8}} )的合成數據集。

 
 

圖2:合成數據集上的結果對比。圖中+x%表示AMFormer相比Transformer的相對提升。

在上述數據中,我們將引入了算數特徵交互的AMFormer架構與經典的XGBoost和Transformer架構對比。實驗結果顯示:

以上結果共同證實了算術特徵交互在深度表格學習中的顯著意義。

算法架構

圖3:AMFormer架構,其中L表示模型層數。

本節介紹AMFormer架構(圖3),並重點介紹算數特徵交互的引入。AMFormer架構借鑑了經典的Transformer框架,並引入了Arithmetic Block來增強模型的算術特徵交互能力。在AMFormer中,我們首先將原始特徵轉換爲具有代表性的嵌入向量,對於數值特徵,我們使用一個1輸入d輸出的線性層;對於類別特徵,則使用一個d維的嵌入查詢表。之後,這些初始嵌入通過L個順序層進行處理,這些層增強了嵌入向量中的上下文和交互元素。每一層中的算術模塊採用了並行的加法和乘法注意力機制,以刻意促進算術特徵之間的交互。爲了促進梯度流動和增強特徵表示,我們保留了殘差連接和前饋網絡。最終,依據這些豐富的嵌入向量,AMFormer使用分類或迴歸頭部生成最終輸出。

算術模塊的關鍵組件包括並行注意力機制和提示標記。爲了補償需要算術特徵交互的特徵,我們在AMFormer中配置了並行注意力機制,這些機制負責提取有意義的加法和乘法交互候選者。這些交互候選隨着會沿着候選維度被串聯(concatenate)起來,並通過一個下采樣的線性層進行融合,使得AMFormer的每一層都能有效捕捉算術特徵交互,即特徵上的四則算法運算。爲了防止由特徵冗餘引起的過擬合併提升模型在超大規模特徵數據集上的伸縮,我們放棄了原始Transformer架構中平方複雜度的自注意力機制,而是使用兩組提示向量(prompt token vectors)作爲加法和乘法查詢。這種方法爲AMFormer提供了有限的特徵交互自由度,並且作爲一個附帶效果,優化了內存佔用和訓練效率。

以上是AMFormer在架構層引入的主要創新,關於模型更詳細的實現細節可以參考原文以及我們的開源實現。

進一步實驗結果

表1:真實數據集統計以及評估指標。

爲了進一步展示AMFormer的效果,我們挑選了四個真實數據集進行實驗。被挑選數據集覆蓋了二分類、多分類以及迴歸任務,數據集統計如表1所示。

表2:AMFormer以及基準方法的性能對比,其中括號內的數字表示該方法在當前數據集上表現的排名,最優以及次優的結果分別以加粗以及下劃線突出。

我們一共測試了包含傳統樹模型(XGBoost)、樹架構深度學習方法(NODE)、高階特徵交互(DCN-V2、DCAP)以及Transformer派生架構(AutoInt、FT-Trans)在內的六個基準算法以及兩個AMFormer實現(分別選擇AutoInt、FT-Trans做基礎架構,即AMF-A和AMF-F),結果彙總在表2中。

在一系列對比實驗中,AMFormer表現更突出。結果顯示,基於MLP的深度學習方法如DCN-V2在表格數據上的性能不盡如人意,而基於Transformer架構的模型顯示出更大的潛力,但未能始終超過樹模型XGBoost。我們的AMFormer在四個不同的數據集上,與所有六個基準模型相比,表現一致更優:在分類任務中,它將AutoInt和FT-transformer的準確率或AUC提升至少0.5%,最高達到1.23%(EP)和4.96%(CO);在迴歸任務中,它也顯著減少了平均平方誤差。相比其它深度表格學習方法,AMFormer具有更好的魯棒和穩定性,這使得在性能排序中AMFormer斷層式優於其它基準算法,這些實驗結果充分證明了AMFormer在深度表格學習中的必要性和優越性。

結論

本工作研究了深度模型在表格數據上的有效歸納偏置。我們提出,算術特徵交互對於表格深度學習是必要的,並將這一理念融入Transformer架構中,創建了AMFormer。我們在合成數據和真實世界數據上驗證了AMFormer的有效性。合成數據的結果展示了其在精細表格數據建模、訓練數據效率以及泛化方面的優越能力。此外,對真實世界數據的廣泛實驗進一步確認了其一致的有效性。因此,我們相信AMFormer爲深度表格學習設定了強有力的歸納偏置。

進一步閱讀

● 論文標題:

Arithmetic Feature Interaction is Necessary for Deep Tabular Learning

● 論文作者:

程奕、胡仁君、應豪超、施興、吳健、林偉

● 論文PDF鏈接:https://arxiv.org/abs/2402.02334

● 代碼鏈接:https://github.com/aigc-apps/AMFormer

原文鏈接

本文爲阿里雲原創內容,未經允許不得轉載。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章