解密prompt系列24. RLHF新方案之訓練策略:SLiC-HF & DPO & RRHF & RSO

去年我們梳理過OpenAI,Anthropic和DeepMind出品的經典RLHF論文。今年我們會針對經典RLHF算法存在的不穩定,成本高,效率低等問題討論一些新的方案。不熟悉RLHF的同學建議先看這裏哦解密Prompt7. 偏好對齊RLHF-OpenAI·DeepMind·Anthropic對比分析

RLHF算法當前存在的一些問題有

  1. RL的偏好樣本的人工標註成本太高,效率低,容易存在標註偏好不一致的問題
  2. RLHF屬於online訓練策略,在訓練過程中需要讓模型進行解碼,時間成本高訓練效率低
  3. RLHF在訓練過程中需要同時部署Reward模型和SFT模型和更新後的模型,顯存佔用高訓練成本高
  4. RLHF需要兩階段的訓練,需要先訓練reward模型,再使用reward模型更新SFT模S型

這一章我們先聊聊訓練策略的新方案。用新方案而不是優化或者改良,因爲平替們的效果需要更長時間的驗證。

SLiC-HF

  • SLiC-HF: Sequence Likelihood Calibration with Human Feedback
  • CALIBRATING SEQUENCE LIKELIHOOD IMPROVES CONDITIONAL LANGUAGE GENERATION

要說SLiC-HF,肯定要先說下前置的Calibartion Sequence likelihood(SLiC)的對齊技術,畢竟上面這兩篇論文的部分作者都相同,思路自然是一脈相承。

SLiC

SLiC對標SFT,也是post-training的指令對齊方案。方案針對指令微調階段使用MLE也就是next token prediction帶來的稀疏訓練問題。因爲給定context,是有無數種output可能的。而微調階段只使用唯一的答案進行訓練,導致模型訓練不充分。一個明顯的現象就是序列的解碼概率越高,並不意味着生成序列的質量越好,這意味着生成序列其實是未修正的(uncalibrated)

SLiC的思路有些類似半監督。也就是標註數據有限,導致模型參數更新的空間有限的情況下,我們可以使用半監督的平滑性和一致性原則,既和標註樣本相似的樣本label相同,反之不同的思路,使用無標註樣本對模型進行更新

那我們把半監督的思路放到文本生成:

第一步.先使用SFT對齊後的模型,針對標註樣本,每個樣本生成m個推理候選結果,這些就是半監督中的未標註樣本

第二步.使用無監督樣本進行對比訓練,核心就是訓練模型對和標註答案更相似的候選樣本給予更高的解碼概率,反之更低

這裏訓練就有兩個細節

  1. 序列相似如何定義?這裏沒有引入新的向量模型,直接使用大模型解碼輸出層的向量表徵(seq * hidden)和標註結果的向量表徵來計算cosine相似度,相似度計算參考了BertScore的F1值。並且這裏對序列進行了切分,分別計算span=1,2,4,8等不同長度的F1值,再進行聚合。

  1. 損失函數如何定義?論文嘗試了以下4種不同的對比損失函數,主要差異在pair-wise還是list-wise,擬合相似度的相對排序(i-j),還是絕對打分(P(yi|x)-P(yj|x))的高低。消融實驗顯示第一個Rank Loss的效果最好。也就是從所有解碼生成的候選中隨機採樣兩個,以上F1更高的爲正樣本,反之爲負樣本。計算解碼概率的Hinge-Loss

這裏論文同樣加入了正則項,避免模型過度偏離原始SFT對齊的模型,分別嘗試了KL和MLE兩種不同的正則。消融實驗顯示KL正則項的效果更好。
所以綜上SLiC使用了無監督的思路,用對比學習來進行對齊。下面我們來看如何使用SLiC來對齊人類偏好

SLiC-HF

偏好樣本

首先SLiC-HF用的是offline的訓練方案,所以先說下偏好樣本是如何構建的。論文嘗試了Direct和Sample and Rank兩種樣本構建方案。

Direct方案就是直接使用Reddit摘要數據集中人工標註的正負偏好樣本作爲\(y^+,y^-\),優點是成本低,缺點是這裏的解碼結果可能和SFT模型的解碼分佈存在偏差。

Sample and Rank,也就是先使用以上偏好數據訓練Reward模型,論文嘗試了兩種方案,一個是絕對偏好,模型預測Good/Bad使用解碼概率作爲label。另一個是相對偏好,也就是模型學習兩個摘要之間的相對好壞。

之後使用SFT模型隨機解碼(temperature=0.7)生成的8個解碼候選,使用以上模型打分或排序後,隨機採樣8個正負樣本對。

效果上Sample and Rank要優於Direct,但如果Driect部分是直接使用SFT模型生成候選再人工標註的話,其實結果可能也不差。

損失函數

已經有了正負樣本對,那其實只需要用到上面的對比損失函數了,不需要使用半監督了。不過這裏的正則器沒有選用KL,而是直接使用SFT樣本的MLE來防止模型能力衰減。最終的損失函數如下

除了Offline的樣本構建訓練效率更高之外,SLiC-HF直接使用序列概率表徵偏好,因此不需要使用reward模型,同時對比來自樣本而非來自模型,因此也不再需要使用凍結參數的SFT模型。訓練過程內容中只有一個SFT模型進行梯度更新。

DPO

DPO和SLiC同樣是基於offline的正負偏好樣本對,通過對比學習來進行偏好對齊。DPO的偏好樣本標註是直接基於SFT模型生成候選,然後人工標註得到正負(win,loss)樣本對,然後直接使用損失函數進行擬合,不訓練reward模型。不過二者的對比損失函數不同,DPO的損失函數如下

以上\(\pi\)是模型解碼輸出層每個token
的輸出概率logp求和,\(\theta\)是參與梯度更新的模型,ref是SFT對齊後的模型參數作爲基準參數被凍結。

所以簡單直觀的理解也就是DPO的損失函數,讓模型對偏好樣本的解碼概率相比ref升高,讓模型對負樣本的解碼概率相比ref下降。和Triplet Loss的對比損失函數的思路有些相似。

我們和SLiC-HF做下對比,首先SLiC是hinge-loss(maximum-margin),DPO不是。其次SLiC是正負樣本直接對比,DPO是正負樣本概率分別和基準模型(SFT模型)進行對比,二者的差異有些類似simases和triplet loss,只不過DPO的錨點不是錨點樣本而是基準模型。所以模型既需要擬合相對偏好,也需要保證絕對分佈不會答覆偏離原始SFT模型。在後面的一些對比論文中普遍結論是DPO的損失函數更優,SLiC的對比函數會導致一些reward hacking

論文還進一步從梯度計算的角度進行了闡述,如果上述損失函數對\(\theta\)求導。會得到以下公式

其中\(\hat{r_{\theta}}(x,y)=\beta log(\frac{\pi_{\theta}(y|x)}{\pi_{ref}(y|x)})\)是DPO的核心,既對齊模型的輸出層的概率偏離原始SFT模型的幅度能隱式表徵偏好,作爲 pseudo Reward來進行模型對齊。正負樣本差異越大越多更新幅度越大,梯度方向是提高偏好樣本的解碼概率,降低負樣本的解碼概率。

RRHF

RRHF同樣是offline構建正負樣本對,再採用對比學習進行偏好對齊的方案,那這裏我們只看RRHF和SLiC的差異點。
其一是RRHF使用了長度歸一化的序列概率來表徵偏好,SLiC直接使用瞭解碼概率

其二是SLiC使用了Hinge-Loss,而RRHF是直接擬合正負樣本的概率差

其三是正負樣本的構建方案,SLiC是基於SFT模型進行隨機解碼生成候選,並基於Reward模型離線構建正負樣本,而RRHF的候選採樣方案還對比了beam-search,diversity-beam-search,以及Iterate-beam-search,也就是每訓練一個epoch基於微調後的模型重新生成一波候選。Iterate-beam-search的採樣方案會有一些效果提升,考慮生成樣本會隨分佈修正而逐漸優化,可以覆蓋更多的分佈空間。以及Iterate-beam-search其實和PPO在線解碼進行模型更新的方案更加相似,但相對效率更高。

三合一大禮包- RSO

STATISTICAL REJECTION SAMPLING IMPROVES PREFERENCE OPTIMIZATION

RSO方案融合了以上三者,主要是DPO和SLiC,分別對損失函數和偏好樣本對的構建方式進行了改良。先說損失函數,RSO把SLiC的Hinge-loss加入到DPO的sigmoid-norm損失函數中,得到了如下的hinge-norm損失函數

再有是偏好樣本構建,RSO指出既然以上對比函數的目標是擬合最優的Policy,那理論上偏好樣本對也應該從\(\pi^*\)來構建。近似於以上RRHF的Iterate-beam-search的最後一個Iterate的樣本分佈。但\(\pi^*\)還沒訓練出來要如何拿到它的對比樣本呢?

這裏RSO提出可以採用從\(\pi_{SFT}\)中拒絕採樣來近似\(\pi_{r}\)的分佈,對比SLiC的SFT-sample-rank,稱之爲RSO-Sample-Rank。具體構建方式還是從SFT生成多個解碼候選,並使用訓練的Reward模型對每個候選進行打分,接着進行拒絕採樣。

首先拒絕採樣使用g(x)擬合f(x), 計算一個常數C,使得\(c*g(x)>=f(x)\)。則採樣過程是從g(x)中採樣,當隨機變量\(U\sim(0,1)<=\frac{f(x)}{c*g(x)}\)則保留樣本,反之拒絕。

這裏g(x)就是SFT模型\(\pi_{sft}\),f(x)是最終對齊的模型\(\pi_{r_{\tau}}\),理論上\(m*\pi_{sft}>=\pi_{r_{\tau}}\),這樣當\(U<= \frac{\pi_{r_{\tau}}}{m*\pi_{sft}}\)我們保留樣本,但因爲這裏的的\(\pi_{r_{\tau}}\)並無法獲得,因此我們用DPO中推導的Policy和reward的關係

爲了diff掉正則項Z,論文使用所有隨機解碼樣本的最大reward的(x,y)來作爲常數C的估計。

最終得到的拒絕採樣的代碼如下

截圖\_選擇區域\_20240211093757

效果上論文對比了DPO,SLiC,RSO,以及不同損失函數,不同採樣方案的效果差異。整體上採樣帶來的收益是更爲顯著,DPO的損失函數上加不加hinge差異並不大,但都會優於SLiC的直接對比損失函數。

截圖\_選擇區域\_20240211093937

想看更全的大模型相關論文梳理·微調及預訓練數據和框架·AIGC應用,移步Github >> DecryPrompt

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章