使用模型平均下的深度網絡的聯邦學習

本文出自Google的Federated Learning of Deep Networks using Model Averaging,主要介紹使用模型平均方法的聯邦式學習。

引言

豐富的數據通常是隱私敏感性的,數量大或者兩者都有,這將會妨礙使用傳統方法登錄到數據中心並在此進行訓練。我們提倡將訓練數據分佈在移動設備上,並通過聚合本地計算的更新來學習一個共享模型。這種方法被稱爲聯邦學習。我們提出了一種實用的深度網絡聯邦學習方法,證明該方法對自然生成的不平衡和無IID(獨立且恆等分佈)數據分佈具有魯棒性。這種方法運行在相對較少的通信環境(聯邦學習的主要約束)中訓練高質量的模型。關鍵點是:儘管我們優化了非凸性損失函數,對來自多個客戶機的更新進行參數平均會產生較好的結果。

一、簡介

  1. 存在許多用於分佈式優化的算法,但這些算法通常具有通信需求且只被一個數據中心網絡結構所滿足。這些算法的理論合理性和實際性能在很大程度上取決於假定數據在計算節點上是獨立且恆等分佈的。綜上所述,這些需求相當於假設完整的訓練集由建模器控制並存儲在一個集中的位置。
  2. 我們研究了一種學習技術,它允許用戶在不需要集中存儲數據的情況下,從這些豐富的數據中去獲得共享模型的好處。這種方法還允許我們利用網絡邊緣可用的廉價計算來擴展學習任務。我們將這種方法稱爲聯邦學習,因爲學習任務是由一箇中央服務器協調的鬆散的參與設備聯合(客戶機)來解決的。每個客戶端都有一個本地的訓練數據集且從未上傳到服務器。相反,每個客戶端計算對由中央服務器維護的當前全局模型的更新,並且只有這個更新可以被交流。由於這些更新是特定於改進當前的模型,所以一旦它們被應用了就沒有理由去存儲它們。
  3. 我們引入了聯邦平均算法,它將每個客戶機上本地SGD訓練與中央服務器執行模型平均的通信循環結合起來。實驗證明:它對於不平衡和無IID的分佈式數據具有魯棒性,並且可以將訓練深度網絡所需要的通信週期減少一到兩個數量級。

二、聯邦式學習

  1. 聯邦學習適合任務的特點:(1)對來自移動設備的真實數據的訓練比對數據中心通常可用的代理數據的訓練具有明顯優勢; (2)這種數據是隱私敏感性的或者規模較大的,因此它不適合將其記錄到數據中心來進行模型訓練;(3)對於監督型任務,數據集上的標籤可以從用戶與他們設備的交互過程中自然推理出來。
  2. 應用實例:(1)圖像分類,例如預測哪些照片最有可能在未來被查看或分享多次;(2)語言模型,可以通過提高解碼、下個單詞預測甚至預測整個回覆來提高觸摸屏鍵盤上的語音識別和文本輸入。
  3. 這些訓練樣本的分佈與容易獲取的代理數據集有很大的不同:聊天和短信中使用的語言與標準語料庫有着很大的不同,人們在手機上拍攝的照片很可能與網絡上的普通照片有所不同。另外,這些問題的標籤集都是直接可用的:輸入的文本是學習語言模型的自標籤,照片標籤是可以通過用戶與照片應用程序的交互過程來定義(哪些照片被刪除、分享或查看)。
  4. 聯邦學習的隱私性:(1)我們必須考慮到一個攻擊者可能通過檢查模型參數學習到什麼內容,這些參數被優化過程中參與到的客戶機所共享。對於真正隱私敏感性的任務,差異化隱私技術可以提供嚴格的最壞情況下的隱私保證,即使對方有任意的邊緣信息。(2)下一個問題是對方可以通過訪問到單個客戶的更新信息來學習到什麼內容。如果一方信任中心服務器,那麼加密和其他標準化安全協議對於這種類型的攻擊是一個基礎的防護。一個強有力的保證可以通過強制執行本地差異化隱私來實現,在這種情況下,我們不是向最終的模型添加噪音干擾,而是干擾每個更新,從而阻止中央服務器對客戶機進行任何確定的推斷。也可以使用安全多方計算來對多個客戶機更新執行聚合,允許使用更少的隨機噪聲來實現本地的差異化隱私。
  5. 大型數據的優勢:在數據中心訓練每個客戶機所需的網絡流量只是每個客戶機本地數據集的大小,必須要被傳輸一次。對於聯邦學習,每個客戶機的流量是每輪的通信量乘以更新的規模。如果更新規模相對於所需的訓練數據數量要小,則後一種數量要小得多。

三、聯邦式優化

  1. 聯邦優化與典型的分佈式優化問題有幾個關鍵的區別:(1)Non-IID:給定客戶端的訓練數據集通常爲基於特定用戶對移動設備的使用情況,因此,任何特定用戶的本地數據集都不能代表總體分佈。(2)不平衡性:一些用戶會頻繁地使用產生訓練集的服務或應用程序,導致一些客戶端有着大量的本地訓練集,而另外一些用戶則只有很少或沒有數據。(3)大規模分佈式:在實際場景中,我們期望參與優化的客戶機數量要遠大於每個客戶機的平均樣例數量。
  2. 我們假設同步更新方案在幾輪通信過程中進行,這裏有一組固定的K個客戶端,每個客戶端都有一個固定的本地數據集。在每輪的開始,客戶端的一個隨機分數C被選擇好,然後服務器將當前全局算法狀態(當前的模型參數)發送到每個客戶端。每個客戶端執行基於全局狀態和本地數據集的本地計算,然後發送更新到服務器。服務器應用這些更新到它的全局狀態,並重復此過程。
  3. 對於一個機器學習問題,我們通常令:fi(w)=ϱ(xi,yi;w)f_i(w)=\varrho(x_i,y_i;w) ,這是模型參數w對示例(xi,yi)(x_i,y_i)的預測損失。我們假定這裏有K個客戶端用於分享數據,用PkP_k表示客戶端k上數據點的索引值,令nk=Pkn_k=|P_k|。於是我們可以使:
              f(w)=k=1KnknFk(w)whereFk(w)=1nkiϵPkfi(w).f(w)=\sum_{k=1}^{K}\frac{n_k}{n}F_k(w) \quad where \quad F_k(w)=\frac{1}{n_k}\sum_{i\epsilon P_k}f_i(w).
  4. 在聯邦優化中,通信代價佔據主導地位。此外,我們希望每個客戶端每天只參與一個小數量的更新循環。另一方面,因爲任何單個設備上的數據集都小於總數據集的大小,且現代智能手機有着相對較快的處理器(包括GPU),與許多型號的通信成本相比,計算基本是免費的。因此,我們的目標是使用額外的計算,目的是減少訓練一個模型所需要的通信循環次數。這裏有兩種我們添加計算的方式:(1)提高並行性:使用更多的客戶端在每個通信循環中獨立工作;(2)增加每個客戶端的計算量。
  5. 我們考慮的(參數化)算法集的一個端點是簡單的一次平均,每個客戶端都爲模型求解來使本地數據損失函數最小化(可能爲正則化),然後這些模型被平均得到最終的全局模型。這種方法在獨立且恆等分佈形式的數據凸情況下進行了廣泛研究,在最壞的情況下,生成的全局模型並不比在單個客戶端上訓練的模型要好。

四、聯邦平均算法

  1. SGD可以很自然地應用於聯邦優化問題,每輪通信只執行一個小批量梯度計算(比如在隨機選擇的客戶端上)。這種方法計算效率高,但是需要大量的訓練才能產生好的模型。這種方式的計算量由三個關鍵參數控制:C(每輪執行計算的客戶端設備的比例);E(每個客戶機在每輪上通過其本地數據集執行的訓練次數);B(用於客戶機更新的批處理大小),其中將 B設爲無窮代表將整個本地數據集看作單個批處理量。

  2. 我們令B=無窮,E=1來生成一種可變小批量規模的SGD形式,這個算法每輪選擇C的客戶端比例,並計算這些客戶端所擁有的所有數據的損失梯度,因此C=1相當於全批次(非隨機)的梯度下降。我們將這種算法稱爲聯邦式SGD,而批次選擇機制不同於通過均勻地隨機選擇單個樣例來選擇批次,其批梯度g仍滿足於E[g]=f(w)E[g]=\triangledown f(w)

  3. 帶有一個固定的學習率η\eta的分佈式梯度下降的典型實現有着每個客戶端k計算g(k)=Fk(wt)g(k)=\triangledown F_k(w_t),即當前模型wtw_t的本地數據集上平均梯度,然後中央服務器將這些梯度集合起來並應用於更新中:
                   wt+1wtηk=1Knkngkw_{t+1} \leftarrow w_t-\eta \sum_{k=1}K\frac{n_k}{n}g_k
    因爲ηk=1Knkngk=f(wt)\eta \sum_{k=1}K\frac{n_k}{n}g_k=\triangledown f(w_t),所以上述表達式也可以表示爲:
              k,wt+1kwtηgkandwt+1k=1Knknwt+1k.\forall k,w^k_{t+1}\leftarrow w_t-\eta g_k \quad and \quad w_{t+1} \leftarrow \sum_{k=1}{K}\frac{n_k}{n}w_{t+1}^k.

  4. 每個客戶端都使用本地數據在當前模型上進行一步梯度下降,然後服務器對得到的模型進行加權平均。對於一個帶有nkn_k本地樣例的客戶端,每輪本地更新的數量由uk=EnkBu_k=E\frac{n_k}{B}來給出。

  5. 最近的工作表明在實踐中,充分參數化的神經網絡的損失曲面表現得很好,特別是不像以前所認爲的那麼容易出現糟糕的局部極小值。當我們從相同的隨機初始化開始兩個模型,然後再一次獨立訓練每一個不同的子數據集,我們發現普通的參數平均工作得很好:兩個模型的平均值,在完整MNIST訓練集上獲得的損失函數值明顯要低於單獨在兩個小數據集上進行訓練所獲得的最好模型。

  6. 實驗算法如下所示:
    算法代碼

五、實驗結果

  1. 對於每個任務,我們選擇一個適度大小的代理數據集,這樣我們就可以徹底地研究FedAvg算法的超參數。雖然每次單獨訓練運行相對要小,但我們爲這些實驗訓練了2000多個單獨的模型。
  2. 第一個任務是MNIST數字識別任務,它有兩種模型構建方式:(1)一個簡單雙隱層模型,每層有200個單元使用ReLu激活,我們將其定義爲MNIST 2NN;(2)一個CNN有着兩個5x5卷積層(第一個有着32個通道,第二個有着64個通道,每層後跟着2x2的最大池化),一個全連接層有着512個單元和ReLu激活,和一個最終的softmax輸出層。
  3. 爲了研究聯邦優化,我們也需要明確數據如何在客戶端上分佈。我們研究了兩種分割MNIST數據的方式:IID(數據被打亂,然後被分到100個客戶機,每個客戶機接收到600個樣本);Non-IID(根據數字標籤將數據排序,將其分成大小爲300的200個碎片化數據,然後指定100個客戶機中每個有2個碎片)。因此,讓我們探索一下我們的算法對蓋度non-IID數據的破壞程度。
  4. 第二個任務是語言模型,爲了研究聯邦優化,我們建立了一個數據集,它來自於莎士比亞全集。這個數據集是明顯不平衡的,一些角色只有幾句臺詞,而一些則有着大量臺詞。使用相同的訓練/測試分割,我們也可以形成一個平衡的IID版本的數據集。
  5. 在這個數據集上我們訓練了一個堆疊的字符級別的LSTM語言模型,在讀取一行中的每個字符後,可以預測下一個字符。該模型以一系列字符作爲輸入,並將每個字符嵌入到一個學習的8維空間中。嵌入的字符通過2LSTM層進行處理,每個LSTM層有着256個節點。最終第二個LSTM層的輸出被髮送到每個字符有一個節點的softmax輸出層。
  6. 提高並行性:我們首先設置C來測試實驗效果,它控制了多客戶端的並行度。爲了計算達到目標測試準確率所需要的通信輪數,我們爲每個參數設置的組合構造了一個學習曲線,使曲線可以單調性改進,然後計算曲線達到目標值的通信輪數。基於實驗中所展示的結果,在剩下的實驗中,我們固定C=0.1,在計算效率和收斂速度中取得了很好的平衡,對於固定B增加C並沒有一個很好的效果,而將B=無窮和B=10的輪數相比起來,可以發現有一個明顯的加速。下表展示了對於MNIST模型C值變化時的影響,每個表實體給出了2NN達到測試精度97%和CNN達到測試精度99%時所需要的通信輪數。下圖則展示了對於MNIST模型測試精度與通信輪數的變化曲線圖,我們將C設置爲0.1。
    提高並行性
    提高並行性
  7. 增加每個客戶機的計算量:我們將C固定爲0.1,每輪爲每個客戶機增加更多的計算,減少B,增加E或兩者都增加。每個客戶機每輪更新的預期數量由 給出,只要B足夠大,就能夠充分利用客戶端硬件上可用的並行性,降低它的計算時間基本上沒有成本,這是第一個調優的參數。當我們將在完全不同的數字對上訓練的模型參數平均時,平均提供了任何優勢。因此,我們認爲這爲這個方案的魯棒性提供了有力證據。下表展示了在MNIST模型上達到目標精度所需要的通信輪數的加速,下圖展示了Shakespeare LSTM模型的學習曲線。可以看到每輪添加更多的本地SGD更新然後再進行模型平均可以產生一個很好的加速效果。
    增加計算量
    增加計算量
  8. 我們推測,除了降低通信成本外,模型平均策略還產生了與dropout類型的正則化好處。我們主要關注泛化性能,但是FedAvg也能有效地優化訓練損失,甚至超過了測試集精度停滯不前的程度。下圖展示了FedAvg對於優化訓練損失的變化影響關係。
    優化訓練損失
  9. 當前模型參數僅通過初始化影響每個客戶端更新中執行的優化。當E趨近於無窮時,至少對於一個凸問題最終的初始條件應當是不相關的,無論初始化如何都會達到全局最小值。而對於一個非凸問題,只要初始化在同一個區域裏,算法就會收斂到相同的局部最小值。實驗結果表明:對於一些模型特別是在收斂階段的後期,每輪本地計算量的衰減(減小E或者增大B)可能是有用的,同樣衰減學習率也是有用的。對於E值較大的情況,收斂速度下降的幅度並不大。下圖上展示了在 Shakespeare LSTM問題上初始化訓練中E值的影響,下則展示了MNIST CNN的實驗影響。
    訓練次數變化

六、結論和未來工作

我們的實驗表明,聯邦學習具有重要的前景,因爲高質量的模型可以使用相對較少的通信輪數進行訓練。下一步的一個重要步驟是,在更大的數據集上對所提出的方法進行進一步的經驗評估,這些數據真正地捕捉到了現實世界問題的大規模分佈式本質。爲了保持算法探索範圍是可控的,我們限制自己以樸素SGD爲基礎。也可以研究我們的方法與其他優化算法,如AdaGrad和ADAM,以及模型結構的變化可以幫助優化,如dropout和批量規範化,這些都是未來工作的一個研究方向。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章