【論文解讀】 Neural Architecture Search with reinforcement learning

【論文解讀】 Neural Architecture Search with reinforcement learning


論文作者:Barret Zhph*,Quoc V.Le [來自於 Google Brain]

論文標題:Neural Architecture Search with reinforcement learning/(使用強化學習進行神經網絡架構搜索)

論文會議:ILCR 2017

論文鏈接:https://arxiv.org/abs/1611.01578

論文代碼: https://github.com/tensorflow/models

1. 概述與介紹


雖然神經網絡卻得了巨大的成功和發展在過去幾年,從 SIFT 與 HOG,到 AlexNet,VGGNet,GoogleNet
但是想要設計一種神經網絡模型仍然需要很多的專業背景知識和時間,本片文章提出了一種叫做 NAS(Neural Architecture Search)
的方法,這種方法基於梯度,用於找到非常好的 Architecture。NAS 的結構如下圖所示:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-ggHwUv2N-1584100103921)(./image/1.png)]

  • 本文提出的方法:
    • 使用 RNN 去生成神經網絡的模型
    • 同時使用強化學習去訓練 RNN,從而希望最大化生成的架構在驗證集上的最大化。
  • 本文現在達到的結果:
    • 計算機視覺任務: - CIFAR-10 數據集上: - test error rate: 0.09 的提升,速度是 1.05 倍更快 - 自然語言處理任務: - Penn Treebank dataset[1]: - 提出的新的 recurrent cell 超過 LSTM,還有 baseline
    • 新的 cell 在 perplexity[2]上相比較 state-of-the-art 有 3.6 的提升
      • PTB 數據集: - 在字符級別的語言模型任務上也是達到了 state-of-the-art
  • 具體的做法: 主要就是將神經網絡模型具體到變成一個可變長度的字符串,所以便於使用 RNN(控制器的作用)來生成字符串.
    用於構建網絡。之後訓練該網絡,並用網絡的 accuracy 作爲 reward 返回給控制器來更新控制器的參數,達到更優的策略。

2. 相關工作


Point1: 超參數優化

Hpyerparameter optimization在過去取得了很多的應用與成功,但是隻能限制於固定長度的空間
中,生成一個規定網絡結構和連接性的可變長度配置很困難。在實踐中,如果這些方法提供了良好的初始模型這些方法往往會更好地工作

(1) 早期的解決方式: 使用貝葉斯優化方法,可以用來搜索非定長的 architecture

(2) 與本文的方法相比: 貝葉斯優化方法實在是不夠通用與靈活

Something we Should know:

神經網絡訓練是由於許多超參數決定的:網絡深度,學習率,卷積核大小,那麼如何獲得一個更好的
超參數組合? 常用的就是Grid Search,Random Search以及貝葉斯優化搜索

1. Grid Search: 其實就是窮舉搜索 超參數的組合全部實現不夠現實,所以事先限定各種可能,效率低
2. Random Search: 缺點是,隨機搜索的結果互相之間差異很大,但是比Grid Search高效

貝葉斯優化:https://www.cnblogs.com/marsggbo/p/9866764.html

Point2: 現代神經進化算法

Modern neuo-evolution algorithms 在組成新的模型的時候很靈活,但是在大規模的時候
不夠實用,限制主要在於是基於搜索的算法,所以運行的時候很慢同時需要啓發式才能夠 work well

Point3: 神經架構搜索

Neural Archhitecture Search 與程序合成和歸納編程有一些相似之處,它們從例子中搜
索程序。在機器學習中,概率性程序引導已成功用於許多環境中,比如解決簡單問答,
對列表數字排序,並以少樣本進行學習。

Point4: idea 的產生以及我們做了什麼!

神經架構搜索中的控制器是自迴歸的,自迴歸就是預測一次的超參數,以先前的預測爲條件。
這個 idea 是借鑑了端到端序列譯碼器對序列學習的思想。與 seq2seq learning 不同,
我們的方法優化了一個不可微的目標,child network 的 accuracy。
類似於神經機器翻譯中的 BLEU[3]優化工作。與這些方法不同,
我們的方法直接從獎勵信號中學習,並且沒有使用有監督的 bootstrapping(自助採樣方法)

與我們的工作相關的還有學習學習或元學習[4]的 idea
這是一個使用在一項任務中學到的信息來改進未來任務的通用框架。
更密切相關的是使用神經網絡學習另一網絡的梯度下降更新
以及使用強化學習爲另一網絡找到更新策略的想法。

端到端與非端到端:
相對於深度學習,傳統機器學習的流程往往由多個獨立的模塊組成,比如在一個典型的自然語言處理(Natural Language Processing)問題中,包括分詞、詞性標註、句法分析、語義分析等多個獨立步驟,每個步驟是一個獨立的任務,其結果的好壞會影響到下一步驟,從而影響整個訓練的結果,這是非端到端的。
而深度學習模型在訓練過程中,從輸入端(輸入數據)到輸出端會得到一個預測結果,與真實結果相比較會得到一個誤差,這個誤差會在模型中的每一層傳遞(反向傳播),每一層的表示都會根據這個誤差來做調整,直到模型收斂或達到預期的效果才結束,這是端到端的。
兩者相比,端到端的學習省去了在每一個獨立學習任務執行之前所做的數據標註,爲樣本做標註的代價是昂貴的、易出錯的。

3. 方法 METHODS


Section 1. 使用 RNN 作爲控制器來生成模型描述

下圖所示是論文中提到如何使用 RNN 去預測生成一個簡單的 conv 層的超參數

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-UiCEjxN2-1584100103924)(./image/2.png)]

  • 上圖中預測的網絡知識包括 conv 層,使用 RNN 預測生成 conv 層的超參數,主要包括
    卷積核的 height,卷積核的 Width,卷積核滑動 stride 的 Height,卷積核滑動 stride 的 Width
  • 實驗當中終止的條件是當網絡層數達到一個值的時候就會停止
  • 控制器生成一個網絡結構後,使用訓練數據集進行訓練直到達到收斂,然後再 hand-out 驗證集上進行測試
    得到一個準確率。
hand-out validation set: 實際上就是留出的驗證集,將數據集分爲訓練集S,和測試集T,同時兩個集合互斥

Section 2. 使用強化學習的思想進行訓練

(1) 主要是思想爲: RNN 的參數使用θc\theta_{c}表示,controllercontroller所預測的一系列tokenstokens記爲一系列的actionsactions,
a1:Ta_{1:T},這些 tokens 是爲了ChildnetworkChild network,子網絡再驗證集上得到的準確率用RR進行表示,這種準確率稱爲
rewardsignalreward signal,並且使用強化學習來訓練controllercontroller

(2) 目標函數: 如下圖所示,實際上就是需要maximizerewardmaximize reward來找到最優的結構

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-YfAIso46-1584100103925)(./image/3.png)]

由於獎勵信號 R 是不可微分的,因此使用策略梯度迭代的去更新θc\theta_{c},在本文中,使用到來自 Williams(1992)Williams (1992) 的強化學習規則
如下圖所示

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-JALjKNYh-1584100103926)(./image/4.png)]

對於上面這個等式實際上就約等於如下所示的等式:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-s1YiNaIP-1584100103928)(./image/5.png)]

公式推導如下所示:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-F2GpK5lD-1584100103929)(./image/10.png)]

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-IMVh2M63-1584100103929)(./image/11.png)]

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-Crs63cH1-1584100103930)(./image/12.png)]

公式參數解讀:
m:是控制器在一個batch中採樣得到的結構數量
T:是controller用於預測和設計神經網絡結構的超參數的數量
R_k表示第k個網絡結構在驗證集上的準確度

上述的更新算法是對梯度的無偏估計,但是缺點在於方差太高,解決方法如下圖所示,採用了一個 baseline 函數

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-nmvJ7Ezf-1584100103931)(./image/6.png)]

其中bb不依賴於當前的 action,那麼其仍是無偏梯度估計,且bb是前面結構準確度的指數平均指標(Exponential Moving Average,EMA)

無偏估計:無偏估計是用樣本統計量來估計總體參數時的一種無偏推斷。估計量的數學期望等於被估計參數的真實值,則稱此估計量爲被估計參數的無偏估計,即具有無偏性,是一種用於評價估計量優良性的準則。無偏估計的意義是:在多次重複下,它們的平均數接近所估計的參數真值。無偏估計常被應用於測驗分數統計中。

EMA(Exponential Moving Average)是指數平均數指標,它也是一種趨向類指標,指數平均數指標是以指數式遞減加權的移動平均。
用在這裏的目的就是第i步的梯度下降的步長種增加了權重係數,相當於做了一個learning rate decay.

(3) 使用並行算法和異步來進行加速學習(氪金的味道,Google 親爹)

每一次用於更新 controller 的參數θc\theta_{c}的梯度都對應於一個子網絡訓練達到收斂。但是因爲子網絡衆多,
且每次訓練收斂耗時長,所以使用 分佈式訓練和異步參數更新的方法來加速 controller 的學習速度。具體結構如下
圖所示:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-P2ewICSU-1584100103932)(./image/7.png)]

主要就是有 S 個 Parameter Server 用於存儲 K 個 Controller 複製體的共享參數,然後然後每個
Controller Replica 生成 m 個並行訓練的自網絡。

controller 會根據 m 個子網絡結構在收斂時得到的結果收集得到梯度值,然後爲了更新所有 Controller Replica,會把梯度值傳遞給 Parameter Server。
在本文中,當訓練迭代次數超過一定次數則認爲子網絡收斂。

Section 3: 使用跳躍連接和其他 Layer Types 來提升架構的複雜度

這個 Section 實際就是講需要使用 Skip ConnecionS(ResNet 結構)和 branching Layers(層分叉,GoogleNet 結構)。
同時爲了準確預測 connections,本文使用了基於注意力機制的 set-selection type attenion 方法。

方法:

(1)每個 layer 添加一個 anchor point, 則經過 anchor point, RNN Controller 有一個 hidden state hih_i

(2)在 N 層,根據 sigmoid 函數採樣 j 是否連接到 i layer,sigmoid 函數如下所示:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-Ozw6COhk-1584100103933)(./image/8.png)]

上式中Wcurr,Wprev,vW_{curr},W_{prev},v是可以學習的參數

下圖實質表示的是如何使用 skipconnectionsskip connections 去決定那一層是其想要輸入的當前層

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-pdeXX28h-1584100103934)(./image/9.png)]

在本文當中對幾個問題進行了處理:

(1) 如果沒有輸入,那麼原始圖像作爲輸入,看成是 inputlayerinput layer

(2) layerlayer 輸出可能沒被送到任何其他 layerslayers:都送到classifierclassifier

(3) 如果需要 concatenatedconcatenated 的輸入層有不同的sizesize,那麼小一點的層通過補 0 來保證一樣大小

添加其他類型 layerslayers

pooling, batchnorm, 甚至是Learning rate

RNN Controller 首先預測 layer type,再預測相關的 hyperparameters

Section 4: 生成 RECURRENT CELL 架構

主要是講如何生成遞歸單元結構的具體細節,使用樹結構來描述網絡結構,這樣也便於便利結點
,其中每棵樹由兩個葉子節點(0,1)和中間節點(用 2 表示)組成.這種結構也可以稱爲"base 2"結構。
具體如下所示:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-0V4ixIis-1584100103935)(./image/13.png)]

4. 實驗與結果


  • 主要是在 CIFAR-10 數據集上進行圖像分類以及在 Penn Treebank 上進行 language modeling 任務

  • 其中數據集,baseline 模型以及超參設置,具體可見論文。

5.結論與總結


(1) NAS 在生成網絡的時候之前需要固定網絡的結構,或者是說需要固定網絡的層數。

(2) 以生成 CNN 網絡爲例,代碼中默認最大層數參數 max_layers=2,當然也可以人爲修改。

(3) 而 controller 其實就是一個 RNN 網絡,其輸出數據表示某一層中各個節點的參數,各個參數是按順序輸出的。
例如代碼中是按照[cnn_filter_size,cnn_num_filters,max_pool_ksize,cnn_dropout_rates] 輸出。

僞代碼:

state = np.array([[10.0, 128.0, 1.0, 1.0]*max_layers], dtype=np.float32) # 初始化state
for episode in range(MAX_EPISODES):
	action = RLnet.get_action(state)  # 強化學習網絡根據當前狀態獲取下一步的動作,其中是使用原論文所給的NAScell來對動作進行預測的。
	reward, pre_accuracy = net_manager.get_reward(action) # 根據生成的動作得到對應的網絡,然後將該網絡在訓練集上訓練至收斂,再將收斂後的網絡在驗證集上運行得到準確度,根據一定的準則將準確度轉化爲reward。
	reward = update(reward) # 更新reward
	state = update(action) # 根據action更新state,在例子中是state=action[0]

從上面的僞代碼可以看出每次採樣得到的模型都需要在訓練集上訓練到收斂,然後再根據在驗證集上得到的 reward 更新。所以 NAS 其本質是在離散搜索空間進行搜索,
而且網絡拓撲結構是固定的,並且訓練時間較長。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-U1kNArQg-1584100103936)(./image/15.png)]

代碼復現:https://github.com/wallarm/nascell-automl

一些內容的解釋:
[1] Penn Treebank dataset: 預料來自於華爾街日報,對預料進行標註,標註性內容包括詞性分析標註以及句法分析
[2] perplexity:困惑度,每個時間步根據語言模型計算的概率分佈隨機挑詞,那麼平均情況下,挑多長時間才能挑來正確的那個
[3] BLEU:bilingual evaluation understudy,即:雙語互譯質量評估輔助工具。它是用來評估機器翻譯質量的工具,BLEU的設計思想與評判機器翻譯好壞的思想是一致的:機器翻譯結果越接近專業人工翻譯的結果,則越好。BLEU算法實際上在做的事:判斷兩個句子的相似程度。我想知道一個句子翻譯前後的表示是否意思一致,顯然沒法直接比較,那我就拿這個句子的標準人工翻譯與我的機器翻譯的結果作比較,如果它們是很相似的,說明我的翻譯很成功。因此,BLUE去做判斷:一句機器翻譯的話與其相對應的幾個參考翻譯作比較,算出一個綜合分數。這個分數越高說明機器翻譯得越好。(注:BLEU算法是句子之間的比較,不是詞組,也不是段落)
[4] 元學習:Meta Learning 元學習或者叫做 Learning to Learn 學會學習.元學習通常被用在:優化超參數和神經網絡、探索好的網絡結構、小樣本圖像識別和快速強化學習等。
論文寫作鑑賞:
  start from scratch: 白手起家
  paradigm sift: 範式轉移,形容變革
參考文獻:

https://blog.csdn.net/stay_foolish12/article/details/91554801

https://blog.csdn.net/Lucifer_zzq/article/details/83188462

https://www.zhihu.com/question/50454339/answer/257372299

https://zhuanlan.zhihu.com/p/54361495

https://www.jianshu.com/p/4c17bef0ff85#fnref3
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章