論文《Matching Networks for One Shot Learning》閱讀

Matching Networks for One Shot Learning

摘要

In this work, we employ ideas from metric learning based on deep neural features and from recent advances that augment neural networks with external memories. Our framework learns a network that maps a small labelled support set and an unlabelled example to its label, obviating the need for fifine-tuning to adapt to new class types. We then defifine

one-shot learning problems on vision (using Omniglot, ImageNet) and language tasks.

1、Introduction

This motivates the setting we are interested in: “one-shot” learning, which consists of learning a class from a single labelled example.

深度學習雖然在很多方面取得來很大的進步,但是缺點之一是需要大量的數據集。數據加強和正則化雖然可以減輕在只有小數據樣本上的過擬合,但是不能從根本上解決這個問題。學習過程非常緩慢並且依賴於大的數據集,需要使用SGD進行很多次的權重更新。這很大程度上是由於模型參數化的原因,模型需要非常緩慢的學習到它的模型。

相反,許多非參數的模型可以快速接受新例子,同時不會產生災難性的遺忘。 一些非參數的模型不需要訓練但是會依賴選擇的度量。我們的目標是將參數化和非參數化模型中的最佳特性結合起來-即快速獲取新的示例,同時從常見的示例中提供優秀的概括。

我們模型的新穎之處在於兩個方面,一方面,在模型層面和訓練過程,我們提出了匹配網絡,一種利用注意力和記憶的最新進展來實現快速學習的神經網絡。另一方面,我們的訓練程序是基於一個簡單的機器學習原則:測試和訓練條件必須匹配。因此,爲了訓練我們的網絡進行快速學習,我們訓練它,每個類只顯示幾個示例,將任務從minibatch切換到minibatch。

在下文將介紹模型,一般設置和實驗。

2 Model

非參數的方法去解決one-shot問題基於兩個方面,首先是模型結構跟隨了神經網絡記憶增強方面的最新進展,給一個支持集,模型定義一個函數,對每一個支持集都做一個映射。其次,採用一種訓練策略,which is tailored for one-shot learning from the support set S.

2.1 Model Architecture

We draw inspiration from models such as sequence to sequence (seq2seq) with attention [2], memory networks [29] and pointer networks [27].

在這些模型中,neural attention mechanism都是完全可微的,被定義去讀記憶矩陣,記憶矩陣中存儲瞭解決任務的有用信息。

set-to-set framework

關鍵點在於在訓練的時候,不需要更改網絡,匹配網絡更夠生成爲沒見過的類生成合理的測試標籤。更準確的說,對於給定的測試樣例x,得到一個輸出y的概率分佈。定義一個映射 P是神經網絡參數化的。然後,對於給的一個新的樣例支持集,我們使用用P定義的參數神經網絡去預測測試樣例x的標籤y。我們的預測輸出實際上是

 

 

上述式子本質上描述了一個新類別的輸出是支持集中標籤的線性組合。其中注意力機制a是一個x×x的內核,上述式子類似一個內核密度估計器。依據一個距離度量和適當的常數,如果的距離超過b,注意力機制就爲0,則上述式子類似‘k-b'最鄰近。然而方程1即包括KDE也包括KNN。我們可以把上述理解爲一種特殊的聯想記憶,給一個輸入,我們指向支持集中相應的示例,找到它的標籤。然而,與其他注意力記憶機制[2]不同,(1)在本質上是非參數的:隨着支持集大小的增加,所使用的記憶也是如此。因此,分類器CS(ˆx)定義的函數形式非常靈活,可以很容易地適應任何新的支持集。

2.1.1 The Attention Kernel

方程一依賴於注意力機制a的選擇,它完全指定了分類。最簡單的形式是在餘弦距離上使用softmax,,並且帶有嵌入函數f,g去嵌入x。

雖然與度量學習有關係,但是我們發現用式子1定義的分類器是對於給定的支持集和分類樣本,讓對充分對齊是足夠的。這種損失還與Neighborhood Component Analysis (NCA) [18], triplet loss [9] or large margin nearest neighbor [28]有關係。

損失是簡單的可微的,所以我們可以找到一個端到端的參數優化。

2.1.2 Full Context Embeddings

模型的主要新穎至於在於重新解釋一個學習的很好的框架(帶有外部記憶的神經網絡)去做one-shot學習。與度量學習密切相關的是,嵌入函數f和g通過對空間特徵X的提升實現最大化在方程一中提到的分類函數的準確率。

儘管分類策略完全依賴於通過設置的整個支持集,我們可以使用餘弦相似性去“attend”,“point”或者是簡單的進行最鄰近計算都是myopic,因爲每一個元素x都是通過g(x)獨立嵌入的。此外,通過函數f,S能夠修改我們是如何嵌入test x的。

我們建議通過一個函數嵌入集合中的元素,該函數除xi外,還包含完整集S,即g變成g(xi,S)。作爲整個集合S的函數,g可以修改如何嵌入。當元素xi和xj非常相似時這是非常重要的,在這種情況下,更改嵌入x的函數是有益的。我們使用雙向長-短期內存(LSTM)[8]在支持集S的上下文中編碼xi,這被認爲是一個序列(更精確的定義見附錄)。

第二個問題可以通過LSTM來固定,在整個集合S上進行read-attention,輸入等價於x:

是從網絡中產生的特徵,輸入到LSTM中,K是LSTM展開步驟(unrolling steps)的固定數目,g(s)是我們得到的集合使用g嵌入。這允許模型去忽視在支持集S中的一些元素,但是把深度加入到attention的計算中。

2.2 Training Strategy

我們定義一個任務T作爲所有可能標籤集合L上的分佈。通常,我們考慮T將所有數據集都統一到幾個唯一的類(例如,5),每個類的示例(例如,最多5個)。在這種情況下,從任務T中取得標籤集合L,L~T,有5到25個樣例。

爲了形成一個“Episode”去計算梯度和更新我們的模型,我們首先從T中取樣L,L是標籤集合。然後我們用L去取樣支持集S和batch B(S和B都是被標記的樣例)。然後匹配網絡被訓練去最小化以支持集S爲條件的批次B中標籤的預測誤差。這是一種元學習的形式,因爲訓練過程明確地學習從給定的支持集學習,以儘量減少一個批的損失。匹配網絡訓練過程中的目標爲:

用方程2訓練θ 產生一個模型,當從新的標籤不同分佈中抽樣時,模型會工作的很好。關鍵的是,我們的模型不需要對它從未見過的類進行任何微調,因爲它的非參數性質。Obviously, as T diverges far from the T from which we sampled to learn θ, the model will not work

3 Related Work

3.1 Memory Augmented Neural Networks

fifixed vectorsmore expressive models

3.2 Metric Learning

Many links between content based attention, kernel based nearest neighbor and metric learningThe most relevant work is Neighborhood Component Analysis (NCA)和非線性的版本。在one-shot學習中使用整個支持集更合適。

4Experiments

我們的所有實驗都圍繞着相同的基本任務:an N-way k-shot學習任務。每種方法都提供了一組K個標記的例子,這些例子來自每一個以前沒有接受過訓練的N個類。任務是將這些無關的無標籤的樣例分類到這N個類中。我們將多個備選模型(作爲基線)與匹配網絡進行了比較。

4.1 Image Classifification Results

For vision problems, we considered four kinds of baselines: matching on raw pixels, matching on discriminative features from a state-of-the-art classififier (Baseline Classififier), MANN [21], and our reimplementation of the Convolutional Siamese Net.

基準分類器被訓練去分類一張圖片到一個在訓練集中原始出現的類,但是不包括之前說的N類。我們使用這個網絡並且使用最後一層的特性(在softmax之前)進行最鄰近匹配,這在很多任務中都取得了較好的結果。在【11】之後,將卷積連網訓練成原來訓練數據集中相同或不同的任務,然後使用最後一層進行最近鄰匹配。

我們還嘗試僅使用從L採樣的支持集S來進一步微調特徵。這產生了大量的過度擬合,但是考慮到我們的網絡是高度正則化的,可以產生額外的增益。

 

4.1.1 Omniglot

Omniglot包括1623個字符,它們來自50個不同的字母表。每一個都是由20個不同的人繪製的。這有很多個類別但是每個類只有很少的樣例。

The N-way Omniglot task:選擇N個沒有見過的字符類,獨立於字母表作爲L。爲模型提供一個類別一張圖片,作爲S~L,B~L。我們通過旋轉90的倍數來增強數據,使用1200個字符作爲訓練剩下的用作評估。

我們使用CNN作爲嵌入函數,由一堆模塊組成,每一個都是3×3的卷積帶有64個濾波器,然後是批量歸一化,一個非線性的ReLu,2×2的最大池化層。我們將圖片調整爲28×28,這樣我們使用4個模塊就可以1×1×64的結果特徵映射,從而產生我們的嵌入函數f(x)。一個全連接層後跟着softmax用來定義基準分類器。

  1. shot,5-shot,5-way,20-way,我們的模型都比基準表現好。對於k-shot分類器使用更多的樣例是有幫助的;5-way分類比20-way分類更簡單。

4.1.2 ImageNet

我們實驗的設置和Omniglot相同,但是我們考慮一個rand和dogs設置。在rand設置中,我們隨機從訓練集中移除了118個標籤,然後訓練只在這118個類中,我們表示爲。對於dogs設置,我們從狗的後代中刪除了所有類別(共118個)然後在沒有狗的類別上進行訓練,然後在狗的類別上進行測試,。我們設計了一個新的數據集minImageNet,包括60000張大小爲84×84的彩色圖片,有100個類,每一個有600個樣例。我們使用80個類訓練,然後再剩餘的20個類上進行測試,所以我們現在有 randImageNet,dogsImageNet,miniImageNet。

和Omniglot一樣,匹配網絡比基準網絡表現的好。但是miniImageNet比Omniglot任務要難,它讓我們去評估Full Contextual Embeddings的靈活度。不管有沒有微調,FCE提高了匹配網絡的表現。

 

我們在全尺寸的InamgeNet上做實驗。我們的基準分類器是Inception。We also compared to features from an Inception Oracle classififier trained on all classes in ImageNet, as an upper bound. 我們用從Inception 分類器上得到的參數初始化匹配網絡的特徵提取器f和g,而不是從零開始在這些大的任務上訓練匹配網絡,然後我們進一步在5-way 1-shot任務上訓練數據集,結合Full Context Embeddings和我們的匹配網絡和訓練策略。

randImageNet和dogsImageNet的結果展示在表3中。Inception Oracle的表現接近完美。

當僅在上訓練時,匹配網絡比Inception提高了進6%,當在上訓練時,將錯誤率減半。從所有的錯誤來看,“盜夢空間”有時似乎更喜歡圖像,而不是其他圖像(這些圖像往往像第二列中的示例一樣混亂,或者顏色更恆定)。另一方面,匹配網設法從支持集S’中出現的這些異常值中恢復。

如果我們將我們的訓練策略調整爲來自細粒度集的樣本S而不是從Image Net類樹的葉子上統一採樣F標籤,可以實現改進。我們把這作爲今後的工作。

4.1.3 One-Shot Language Modeling

任務如下:給一個缺少詞的查詢句,和一組支持句,每個句子都有一個缺失的單詞和一個對應的標籤,從支持集中選擇最匹配查詢句的標籤。

 

句子來自the Penn Treebank dataset。在每一次試驗中,我們確保集合和批處理都填充了不重疊的句子。這意味着我們不使用頻率很低的單詞。和圖片任務一樣, each trial consisted of a 5 way choice between the classes available in the set。在整個句子匹配任務中,我們使用了20的批處理大小,並且在k=1,2,3之間改變了設置大小。我們確保每一組都有相同數量的句子可供使用。我們將單詞分成隨機抽樣的9000個用於訓練,1000個用於測試,我們使用標準測試集來報告結果。因此,無論是單詞還是句子在測試期間都沒有在訓練期間見過。

我們將我們的one-shot匹配模型與oracle LSTM進行比較。在設置中,LSTM具有一個不公平的優勢,因爲它不是做one-shot學習而是看到所有的數據,所有這應該被當作一個上限。我們檢驗一個相似的設置,其中給模型一個待著一個空的句子並且還有五個可能的單詞,其中包括正確答案。對於這五個詞,模型給出了一個對數似然,並選擇其中數值最大的。

LSTM語言模型oracle在測試集上達到了72.8的準確率。. Matching Networks

with a simple encoding model achieve 32.4%, 36.1%, 38.2% accuracy on the task with k = 1, 2, 3 examples in the set, respectively.

Two related tasks are the CNN QA test of entity prediction from news articles [5], and the Children’s Book Test (CBT) 

5 Conclusion

我們在這篇論文中介紹了匹配網絡,一種新的神經網絡結構,通過相應的訓練制度,能夠對各種one-shot分類任務執行最新的性能。這裏有幾個關鍵的點,首先,如果你訓練網絡進行one-shot學習,那麼one-shot學習就變得容易。其次,神經網絡中的非參數結構使得網絡在相同的任務中更容易記憶和適應新的訓練集。將這些觀測結果結合起來,產生匹配網絡。我們模型的一個明顯的缺點是,隨着支持集的大小增加,每次梯度更新的計算變得更加昂貴。儘管有稀疏的和基於抽樣的方法來緩解這一問題,但我們未來的許多努力將集中在這一限制上。此外,如ImageNet dogs子任務中所示,當標籤分佈具有明顯的偏差(例如細粒度)時,我們的模型會受到影響。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章