統計學習方法之EM算法及其推廣

轉至:http://www.hankcs.com/ml/em-algorithm-and-its-generalization.html

本文是《統計學習方法》第九章的筆記,註解了原著的部分公式推導,補充了另一個經典的雙硬幣模型,並且註釋了一份數十行的EM算法Python簡明實現。

如果概率模型的變量都是觀測變量(數據中可見的變量),則可以直接用極大似然估計,或者用貝葉斯估計模型參數。但是,當模型含有隱變量(數據中看不到的變量)時,就不能簡單地使用這些估計方法,而應該使用含有隱變量的概率模型參數的極大似然估計法,也即EM算法。EM算法.png

EM算法的引入

引入EM算法有兩個常見的例子,一個是三硬幣模型,一個是雙硬幣模型。《統計學習方法》介紹的是三硬幣模型,本文引述該模型後,再補充雙硬幣模型。

三硬幣模型

有ABC三枚硬幣,單次投擲出現正面的概率分別爲π、p、q。利用這三枚硬幣進行如下實驗:

1、第一次先投擲A,若出現正面則投擲B,否則投擲C

2、記錄第二次投擲的硬幣出現的結果,正面記作1,反面記作0

獨立重複1和2十次,產生如下觀測結果:

1 1 0 1 0 0 1 0 1 1

假設只能觀測到擲硬幣的最終結果,無法觀測第一次投擲的是哪一枚硬幣,求π、p、q,即三硬幣模型的參數。

記模型參數爲θ=(π,p,q),無法觀測的第一次投擲的硬幣爲隨機變量z,可以觀測的第二次投擲的硬幣爲隨機變量y,則觀測數據的似然函數爲:

屏幕快照 2016-05-29 下午12.55.28.png

這是個一目瞭然的式子,兩個事件,第一個事件選出那枚看不到的硬幣,第二個事件利用這枚硬幣進行一次投擲。利用硬幣結果只可能是0或1這個特性,可以將這個式子展開爲:

屏幕快照 2016-05-29 下午3.23.13.png

y的觀測序列給定了,怎麼找出一個模型參數,使得這個序列的概率(似然函數的值)最大呢,也就是求模型參數的極大似然估計:

屏幕快照 2016-05-29 下午3.27.34.png

這個問題我認爲是個NP問題,一方面,給定模型參數,可以在多項式時間求出似然函數的值,然而模型參數的組合是無窮的,誰也不知道它是否是最優的。

EM算法簡單理解

EM算法是求解這個問題的一種迭代算法(我認爲並非精確算法,而是近似算法),它有3步:

初始化:選取模型參數的初值:屏幕快照 2016-05-29 下午3.33.10.png,循環如下兩步迭代

E步:計算在當前迭代的模型參數下,觀測數據y來自硬幣B的概率:

屏幕快照 2016-05-29 下午3.36.22.png

這個式子也是一目瞭然的,分子代表選定B並進行一次投擲試驗,分母代表選定B或C並進行一次投擲試驗,兩個一除就得到試驗結果來自B的概率。

M步:估算下一個迭代的新的模型估算值:

屏幕快照 2016-05-29 下午3.37.46.png

這個也好說,把這n個{試驗結果來自B的概率}求和得到期望,平均後,得到B出正面的似然估計,同理有p和q。

重複迭代,直到收斂爲止。

這個模型中,觀測數據Y和隱數據Z組合在一起稱爲完全數據,單獨的觀測數據Y稱爲不完全數據。在隱數據未知的情況,無法直接估計Y的概率分佈。但當模型概率給定時,就可以估計Y的條件概率分佈了。

Y的條件概率分佈估計出來後有什麼用呢?利用Y的條件概率分佈,又可以更新模型參數……那問題來了,爲什麼要這麼做,這麼做能否找到最優解,原理是什麼?

帶着這些問題啃書本稍微有趣一些,在探索這個問題之前,有必要規範地描述EM算法,並引入一些正規的符號和定義:

EM算法的標準定義

輸入:觀測變量數據Y,隱變量數據Z,聯合分佈屏幕快照 2016-05-29 下午5.08.31.png,條件分佈屏幕快照 2016-05-29 下午5.09.23.png

輸出:模型參數θ

(1)  選擇參數的初值屏幕快照 2016-05-29 下午5.26.08.png,開始迭代;

(2)  E步:記屏幕快照 2016-05-29 下午5.11.10.png爲第i次迭代參數θ的估計值,在第i+1次迭代的E步,計算

屏幕快照 2016-05-29 下午5.31.04.png

 

這裏,屏幕快照 2016-05-29 下午5.31.50.png是在給定觀測數據Y和當前的參數估計屏幕快照 2016-05-29 下午5.32.32.png下隱變量數據z的條件概率分佈;

(3) M步:求使屏幕快照 2016-05-29 下午5.33.09.png極大化的θ,確定第i+1次迭代的參數的估計值屏幕快照 2016-05-29 下午5.34.06.png

屏幕快照 2016-05-29 下午5.35.08.png

(4)重複第(2)步和第(3)步,直到收斂。

屏幕快照 2016-05-29 下午5.31.04.png的函數屏幕快照 2016-05-29 下午5.33.09.png是EM算法的核心,稱爲Q函數(Q function)。

定義(Q函數)完全數據的對數似然函數屏幕快照 2016-05-29 下午5.38.46.png關於在給定觀測數據Y和當前參數屏幕快照 2016-05-29 下午5.11.10.png下對未觀測數據Z的條件概率分佈屏幕快照 2016-05-29 下午5.40.08.png的期望稱爲Q函數,即

屏幕快照 2016-05-29 下午5.41.00.png

下面關於EM算法作幾點說明:

步驟(1)參數的初值可以任意選擇,但需注意EM算法對初值是敏感的。

步驟(2)E步求屏幕快照 2016-05-29 下午5.33.09.png。Q函數式中Z是未觀測數據,Y是觀測數據。注意,屏幕快照 2016-05-29 下午5.33.09.png的第1個變元表示要極大化的參數,第2個變元表示參數的當前估計值。每次迭代實際在求Q函數及其極大。

步驟(3)M步求屏幕快照 2016-05-30 下午8.27.14.png的極大化,得到屏幕快照 2016-05-30 下午8.27.33.png,完成一次迭代屏幕快照 2016-05-30 下午8.27.47.png。後面將證明每次迭代使似然函數增大或達到局部極值。

步驟(4)給出停止迭代的條件,一般是對較小的正數屏幕快照 2016-05-29 下午5.44.47.png,若滿足

屏幕快照 2016-05-29 下午5.45.27.png

則停止迭代。

EM算法的導出

看完了冗長的標準定義,認識了一點也不Q的Q函數,終於可以瞭解EM算法是怎麼來的了。

尋找模型參數的目標(或稱標準)是找到的參數使觀測數據的似然函數最大,一般用對數似然函數取代似然函數,這樣可以把連乘變爲累加,方便優化,也就是極大化

屏幕快照 2016-05-29 下午5.59.36.png

這個式子裏面有未知的隱變量Z,無法直接優化。

但是如同在“EM算法簡單理解”中看到那樣,給定模型參數,就可以估計Y的條件概率(後驗概率,已經有Z這個結果,求原因Y的概率)。所以我們就挑一個模型參數的初值,也就是EM算法的第1步。

有了初值,就可以代入似然函數得到一個值,但這個值不一定是最大的,我們想要更大,所以需要調整參數,這也是EM算法爲什麼要迭代的原因。

事實上,EM算法是通過迭代逐步近似極大化似然函數的。假設在第i次迭代後屏幕快照 2016-05-29 下午6.12.05.png的估計值是屏幕快照 2016-05-29 下午6.12.31.png。我們希望新估計值屏幕快照 2016-05-29 下午6.12.05.png能使屏幕快照 2016-05-29 下午6.06.06.png增加,即屏幕快照 2016-05-29 下午6.13.28.png,並逐步達到極大值。爲此,考慮兩者的差:

屏幕快照 2016-05-29 下午6.14.05.png

利用Jensen不等式(Jensen inequality)

屏幕快照 2016-05-29 下午9.23.30.png

得到其下界:

屏幕快照 2016-05-29 下午9.09.35.png

式子有點長,而且用了些技巧,慢慢看。首先第一行的屏幕快照 2016-05-29 下午9.31.06.png是人爲加上去的,先乘以這一項再除以這一項得到的依然是屏幕快照 2016-05-29 下午9.18.56.png,然後第二行就利用了琴生不等式,將log運算符移入求和項中。但屏幕快照 2016-05-29 下午9.31.06.png爲何變成了屏幕快照 2016-05-29 下午9.12.24.png呢?我認爲這是李航博士的筆誤,第一行就應該分子分母同時乘以屏幕快照 2016-05-29 下午9.12.24.png的。參考普林斯頓大學的講義《COS 424- Interacting with Data.pdf》:

屏幕快照 2016-05-29 下午9.44.19.png

應該乘以Z的概率分佈,也就是屏幕快照 2016-05-29 下午9.12.24.png

然後第三行的變換其實很簡單,將log拆成log乘以屏幕快照 2016-05-29 下午9.12.24.png對Z求和的形式,再將每一項中的-log跟前一個求和中的每一項中的log合併,log的減法變成除法就得到最終結果了。

好了,言歸正傳,將屏幕快照 2016-05-29 下午10.02.38.png移到等號右邊去,得到一個函數,取個名字:

屏幕快照 2016-05-29 下午10.15.18.png

那麼就有

屏幕快照 2016-05-29 下午10.27.20.png

得到了屏幕快照 2016-05-29 下午9.18.56.png的一個下界,如果將θ 屏幕快照 2016-05-29 下午10.34.46.png,代入屏幕快照 2016-05-29 下午10.15.18.png,我們會發現,在模型參數一致的情況下,log項中的分子分母都是同一個(Y,Z)的聯合分佈,所以分子分母相等,後面這個求和項等於0,直接得到:

屏幕快照 2016-05-30 上午10.37.19.png

說明屏幕快照 2016-05-29 下午10.27.20.png等號成立的條件是θ 屏幕快照 2016-05-29 下午10.34.46.png,換句話說只要θ 不等於屏幕快照 2016-05-29 下午10.34.46.png,就一定能讓屏幕快照 2016-05-29 下午9.18.56.png變大一點。換句話說,任何能使屏幕快照 2016-05-30 上午10.42.00.png增大的屏幕快照 2016-05-30 上午10.42.48.png,都能使屏幕快照 2016-05-29 下午9.18.56.png增大(通過優化對數似然函數的下界來間接優化它)。爲了儘可能顯著地增大屏幕快照 2016-05-29 下午9.18.56.png,需要選擇屏幕快照 2016-05-30 上午10.45.17.png使得屏幕快照 2016-05-30 上午10.42.00.png達到極大:

屏幕快照 2016-05-30 上午10.46.02.png

現在來推導屏幕快照 2016-05-30 上午10.45.17.png的表達式,去掉上式中所有與屏幕快照 2016-05-30 上午10.42.48.png無關的常數項,有:

屏幕快照 2016-05-30 上午10.47.47.png

推到最後發現屏幕快照 2016-05-30 上午10.45.17.png等於最大化Q函數時的參數,也就是M步執行的那樣。

EM算法是通過不斷求解下界的極大化逼近求解對數似然函數極大化的算法。如圖:

屏幕快照 2016-05-30 上午10.58.15.png

在一個迭代內保證對數似然函數的增加的,迭代結束時無法保證對數似然函數是最大的。也就是說,EM算法不能保證找到全局最優值。嚴密的證明請接着看下一節。

EM算法的收斂性

對數似然函數單調遞增定理 設屏幕快照 2016-05-30 上午11.05.32.png爲觀測數據的似然函數,屏幕快照 2016-05-30 上午11.06.08.png爲EM算法得到的參數估計序列,屏幕快照 2016-05-30 上午11.06.41.png爲對應的似然函數序列,則屏幕快照 2016-05-30 上午11.07.06.png是單調遞增的,即

屏幕快照 2016-05-30 上午11.07.30.png

證明參考《統計學習方法》161頁。

收斂性定理 屏幕快照 2016-05-30 上午11.13.40.png爲觀測數據的對數似然函數,屏幕快照 2016-05-30 上午11.06.08.png爲EM算法得到的參數估計序列,屏幕快照 2016-05-30 上午11.14.31.png爲對應的對數似然函數序列。

(1)如果屏幕快照 2016-05-30 上午11.15.03.png有上界,則屏幕快照 2016-05-30 上午11.15.51.png收斂到某一值屏幕快照 2016-05-30 上午11.16.09.png;

(2)在函數屏幕快照 2016-05-30 上午11.16.40.png屏幕快照 2016-05-30 上午11.16.59.png滿足一定條件下,由EM算法得到的參數估計序列屏幕快照 2016-05-30 上午11.17.43.png的收斂值屏幕快照 2016-05-30 上午11.18.43.png屏幕快照 2016-05-30 上午11.19.12.png的穩定點。

證明依然參考《統計學習方法》162頁,事實上,連原著都省略了(2)的證明。

既然EM算法不能保證找到全局最優解,而且初值會影響最終結果,那麼實際應用中有什麼技巧呢?答案是多選幾個初值,擇優錄取。

原著接下來介紹了EM算法在高斯混合模型中的應用,以及EM算法的推廣。這在超出了我目前對理論的需求,所以暫時打住,進入實踐環節。

EM算法的簡明實現

當然是教學用的簡明實現了,這份實現是針對雙硬幣模型的。

雙硬幣模型

假設有兩枚硬幣A、B,以相同的概率隨機選擇一個硬幣,進行如下的拋硬幣實驗:共做5次實驗,每次實驗獨立的拋十次,結果如圖中a所示,例如某次實驗產生了H、T、T、T、H、H、T、H、T、H,H代表正面朝上。

假設試驗數據記錄員可能是實習生,業務不一定熟悉,造成a和b兩種情況

a表示實習生記錄了詳細的試驗數據,我們可以觀測到試驗數據中每次選擇的是A還是B

b表示實習生忘了記錄每次試驗選擇的是A還是B,我們無法觀測實驗數據中選擇的硬幣是哪個

問在兩種情況下分別如何估計兩個硬幣正面出現的概率?

EM算法.png

a情況相信大家都很熟悉,既然能觀測到試驗數據是哪枚硬幣產生的,就可以統計正反面的出現次數,直接利用最大似然估計即可。

b情況就無法直接進行最大似然估計了,只能用EM算法,接下來引用nipunbatra博主的簡明EM算法Python實現。

建立數據集

針對這個問題,首先採集數據,用1表示H(正面),0表示T(反面):

  1. # 硬幣投擲結果觀測序列
  2. observations = np.array([[1, 0, 0, 0, 1, 1, 0, 1, 0, 1],
  3.                          [1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
  4.                          [1, 0, 1, 1, 1, 1, 1, 0, 1, 1],
  5.                          [1, 0, 1, 0, 0, 0, 1, 1, 0, 0],
  6.                          [0, 1, 1, 1, 0, 1, 1, 1, 0, 1]])

初始化

選定初值,比如

屏幕快照 2016-05-30 下午4.49.49.png

第一個迭代的E步

拋硬幣是一個二項分佈,可以用scipy中的binom來計算。對於第一行數據,正反面各有5次,所以:

  1. coin_A_pmf_observation_1 = stats.binom.pmf(5,10,0.6)

輸出

  1. 0.20065812480000034

類似地,可以計算第一行數據由B生成的概率:

  1. coin_B_pmf_observation_1 = stats.binom.pmf(5,10,0.5)

輸出:

  1. 0.24609375000000025

將兩個概率正規化,得到數據來自硬幣A的概率:

  1. normalized_coin_A_pmf_observation_1 = coin_A_pmf_observation_1/(coin_A_pmf_observation_1+coin_B_pmf_observation_1)
  2. print "%0.2f" %normalized_coin_A_pmf_observation_1

這個值類似於三硬幣模型中的μ,只不過多了一個下標,代表是第幾行數據(數據集由5行構成)。同理,可以算出剩下的4行數據的μ。

有了μ,就可以估計數據中AB分別產生正反面的次數了。μ代表數據來自硬幣A的概率的估計,將它乘上正面的總數,得到正面來自硬幣A的總數,同理有反面,同理有B的正反面。

  1. # 更新在當前參數下A、B硬幣產生的正反面次數
  2. counts['A']['H'] += weight_A * num_heads
  3. counts['A']['T'] += weight_A * num_tails
  4. counts['B']['H'] += weight_B * num_heads
  5. counts['B']['T'] += weight_B * num_tails

第一個迭代的M步

當前模型參數下,AB分別產生正反面的次數估計出來了,就可以計算新的模型參數了:

  1. new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])
  2. new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])

對於第一個迭代,新的模型參數分別爲:

屏幕快照 2016-05-30 下午5.22.28.png

與論文是一致的,於是就可以整理一下,給出EM算法單個迭代的代碼:

  1. def em_single(priors, observations):
  2.     """
  3.     EM算法單次迭代
  4.     Arguments
  5.     ---------
  6.     priors : [theta_A, theta_B]
  7.     observations : [m X n matrix]
  8.  
  9.     Returns
  10.     --------
  11.     new_priors: [new_theta_A, new_theta_B]
  12.     :param priors:
  13.     :param observations:
  14.     :return:
  15.     """
  16.     counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}
  17.     theta_A = priors[0]
  18.     theta_B = priors[1]
  19.     # E step
  20.     for observation in observations:
  21.         len_observation = len(observation)
  22.         num_heads = observation.sum()
  23.         num_tails = len_observation - num_heads
  24.         contribution_A = stats.binom.pmf(num_heads, len_observation, theta_A)
  25.         contribution_B = stats.binom.pmf(num_heads, len_observation, theta_B)   # 兩個二項分佈
  26.         weight_A = contribution_A / (contribution_A + contribution_B)
  27.         weight_B = contribution_B / (contribution_A + contribution_B)
  28.         # 更新在當前參數下A、B硬幣產生的正反面次數
  29.         counts['A']['H'] += weight_A * num_heads
  30.         counts['A']['T'] += weight_A * num_tails
  31.         counts['B']['H'] += weight_B * num_heads
  32.         counts['B']['T'] += weight_B * num_tails
  33.     # M step
  34.     new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])
  35.     new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])
  36.     return [new_theta_A, new_theta_B]

EM算法主循環

給定循環的兩個終止條件:模型參數變化小於閾值;循環達到最大次數,就可以寫出EM算法的主循環了:

  1. def em(observations, prior, tol=1e-6, iterations=10000):
  2.     """
  3.     EM算法
  4.     :param observations: 觀測數據
  5.     :param prior: 模型初值
  6.     :param tol: 迭代結束閾值
  7.     :param iterations: 最大迭代次數
  8.     :return: 局部最優的模型參數
  9.     """
  10.     import math
  11.     iteration = 0
  12.     while iteration < iterations:
  13.         new_prior = em_single(prior, observations)
  14.         delta_change = np.abs(prior[0] - new_prior[0])
  15.         if delta_change < tol:
  16.             break
  17.         else:
  18.             prior = new_prior
  19.             iteration += 1
  20.     return [new_prior, iteration]

調用

給定數據集和初值,就可以調用EM算法了:

  1. print em(observations, [0.6, 0.5])

得到

  1. [[0.79678875938310978, 0.51958393567528027], 14]

與論文中的結果是一致的(我們多迭代了4次,畢竟我們不清楚論文作者設置的終止條件):

屏幕快照 2016-05-30 下午5.29.24.png

我們可以改變初值,試驗初值對EM算法的影響。

  1. em(observations, [0.5,0.6])

得到

  1. [[0.51958345063012845, 0.79678895444393927], 15]

看來EM算法還是很健壯的

如果把初值設爲相等會怎樣?

調用

  1. em(observations, [0.3,0.3])

得到

  1. [[0.66000000000000003, 0.66000000000000003], 1]

這顯然是不是個好主意,再試試很極端的情況:

  1. em(observations, [0.9999,0.00000001])

得到

  1. [[0.79678850504581944, 0.51958235686544463], 13]

可見EM算法仍然很聰明。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章