關於膠囊網絡(Capsule Net)的個人理解

        最近在跟蹤keras的contri版的更新時,發現了冒出了一個Capsule層。於是我百度+谷歌一頓操作猛如虎,才發現在很早之前,膠囊網絡的概念就提出了。但是限於膠囊網絡的performance並不是在各個數據集都是碾壓的情況,並且其計算量偏大,訓練時間偏長,所以並沒有被廣泛的運用和替換。但是在官方給出的測試結果來看,其實效果還是挺不錯的。

以上是原論文(https://arxiv.org/pdf/1710.09829.pdf)在mnist是數據集上的結果,結果是指錯誤率,可見效果還是有小幅提升的。在介紹膠囊網絡之前,給大家推薦一篇博客,寫得很好。https://www.jiqizhixin.com/articles/2017-11-05。同時,在keras上有開發者提供的capsule層源碼:https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/layers/capsule.py。本文將會以該源碼爲主,從各大博客的角度以及個人觀點介紹膠囊網絡。

另外,歡迎大家看看我之前介紹的bert模型 https://blog.csdn.net/weixin_42078618/article/details/94394348,兩者在層的設計上游異曲同工之處

一、什麼是膠囊網絡?

看看博客中的一張圖,簡單明瞭。

簡而言之就是,我有一個權重(叫做膠囊權重),它來來回回地跟輸入數據做計算,一邊計算一邊更新自己(路由算法)。

步驟如下:

(1)我先對膠囊層上一層的數據先做個預處理,處理前需要一個權重來爲處理做計算過度,這個權重就是W_{ij},處理成膠囊層需要的數據結構,即:u_{j}|_{i}(1號標)。所以這裏的u_{i}是指膠囊層上一層的輸出,即將囊層的輸入。真正的膠囊層計算從u_{j}|_{i}開始

(2)接着,我又定義了(或計算出)一組權重c_{i}_{j}(2號標),跟u_{j}|_{i}作矩陣乘法,得到s_{j}(3號標),重複此步。

(3)然後,對s_{j}做激活,得到v_{j}(4號標)

OVER!

看到這,我相信很多人還是一頭霧水,當時我也是一頭霧水。。。屮艸芔茻。。。於是我去讀了一遍keras的源碼和官方example,我才發現,其實,跟這個還是有點區別的。咱們再看一個keras官方的實現邏輯。

二、膠囊網絡的結構

首先,我們假設上一層的數據輸入是N x 72 x 64(N個句子,每個句子最多72個字,每個字的嵌入維度是64)的一個數據格式,膠囊數爲5,膠囊輸出維度爲10,路由次數爲4

(1)生成過度權重

# (1, 72, 5*10)
self.W = self.add_weight(name='capsule_kernel',
                                     shape=(1,
                                            input_dim_capsule,
                                            self.num_capsule *
                                            self.dim_capsule),
                                     initializer=self.initializer,
                                     regularizer=self.regularizer,
                                     constraint=self.constraint,
                                     trainable=True)
# (N, 72, 64) conv1d (1, 64, 50) --->>> (N, 72, 50)
u_hat_vectors = K.conv1d(inputs, self.W)

爲什麼說他是過度權重呢?

首先,該層傳入的數據,它不符合計算的數據格式,需要對它做一些調整。

比如:這層需要做一個3行5列的矩陣跟一個5行7列的矩陣做矩陣乘法,但是我輸入的數據是3行6列的,怎麼辦?這裏作者使用了卷積,把3行6列捲成3行5列。

怎麼卷?一維卷積!相當於用了50個channel=64的1x1的卷積核做卷積,這裏一定要想明白。相當於keras的Conv1D

那爲什麼不用Dense(全連接層)?也可以,當然可以。但是卷積的參數量小嘛,並且可以減少過擬合的風險,當然這裏我個人認爲用全連接也未嘗不可,可能效果沒有卷積好。

其次,如上所說,用全連接也可以達到這個效果,用池化行不行?也行!所以這步本質上只是一個過渡,過渡到膠囊核心計算所要求的數據格式。

到此,咱們的數據變成(N, 72, 50)的結構

(2)調整膠囊結構

# (N, 72, 5, 10)
u_hat_vectors = K.reshape(u_hat_vectors, (batch_size,
                                                  input_num_capsule,
                                                  self.num_capsule,
                                                  self.dim_capsule))
# (N, 72, 5, 10) --->>> (N, 5, 72, 10)
u_hat_vectors = K.permute_dimensions(u_hat_vectors, (0, 2, 1, 3))

這兩句怎麼理解呢?我們可以看成該數據有N個句子,每個句子最大72個字,每個字由5個膠囊單元組成,每個單元包含10個維度的抽象信息。

我們把該數據組合做一下轉置,轉成:N個句子,每個句子由5個膠囊單元組成,每個膠囊單元由72個字組成,每個字包含10個維度的抽象信息。

到這步,你會發現,這步操作跟自注意力的多頭機制是一模一樣的、一模一樣的、一模一樣的!瞭解過多頭自注意力的看官,接下里看着感覺就是炒多頭的舊飯!

(3)初始化一個膠囊權重

routing_weights = K.zeros_like(u_hat_vectors[:, :, :, 0])

這個權重會貫穿整個膠囊單元的計算

(4)膠囊單元計算(路由算法)

1)softmax調整權重

capsule_weights = K.softmax(routing_weights, 1)

這一步可以先跳過,待會回頭來看這一步

2)打分機制

# (N, 5, 72) * (N, 5, 72, 10) --->>> (N, 5, 5, 10)
outputs = K.batch_dot(capsule_weights, u_hat_vectors, [2, 2])

這一步,又跟多頭自注意力是一致的。唯一的區別是,多頭自注意力是互相打分,而這裏的膠囊計算是一個人給其他所有人打分!膠囊權重參數跟每一個channel的膠囊單元進行矩陣乘法,即:打分!

3)對膠囊個數所在維度進行求和

if K.ndim(outputs) == 4:
    # (N, 5, 10)
    outputs = K.sum(outputs, axis=1)

這是根據論文的原意進行的計算操作。這一步個人感覺,主要是爲了把數據繼續拉回符合下一步計算的格式。你說用K.mean可不可以,用K.max可不可以,我覺得似乎問題都不大。

4)L2正則。常規減少過擬合操作

outputs = K.l2_normalize(outputs, -1)

5)再一次打分機制!

# (N, 5, 5, 72)
routing_weights = K.batch_dot(outputs, u_hat_vectors, [2, 3])

這一次的矩陣乘法維度是[2, 3],這次設計打分是對膠囊單元的維度進行打分

看到這裏,你還敢說跟多頭自注意力機制不像?

6)繼續求和,拉回到初始膠囊單元的數據格式

if K.ndim(routing_weights) == 4:
    # (N, 5, 72)
    routing_weights = K.sum(routing_weights, axis=1)

到這,那簡直就是多頭自注意力模型的變體。他倆幾乎一模一樣。

我們先回顧一下,多頭自注意力模型的核心過程。

a、對輸入數據做拆分,分成QKV的前身

b、對QKV進行多頭拆分

c、Q跟K互相打分--->>>得到W

d、W跟V互相打分--->>>得到O

e、對O進行多頭合併

g、如此循環b——e,便可以形成深層多頭自注意力模型

咱們再回顧一下膠囊網絡的核心過程。

a、對輸入數據做一維卷積,得到膠囊輸入U的前身

b、路由權重W跟膠囊輸入U打分,求和得到輸出O

c、輸出O跟膠囊輸入U打分,求和更新路由權重W

d、重複b——c,便可以形成深層路由層

所以,仔細想想,真正實現出來的膠囊網絡跟論文以及上篇博客論述的膠囊網絡,還是有點差異的。這裏我把我從源碼讀出來的膠囊網絡跟大家分享一下

中間的核心計算層,他簡直就是自注意力層的一個縮小版!!!

r_w表示路由權重,c_w表示膠囊權重,他們都只是一個權重在不同階段的狀態罷了。

他跟自注意力層的區別有兩大點:

1、打分機制不同。attention一致都是自己跟自己打分- -,且同一維度跟同一維度打分;而膠囊網絡是設置一個可學習的單維權重(相對膠囊輸入數據而言是單維的),跟多維的輸入數據進行數據維度一對多的打分,在跟膠囊維度一對多的打分。

2、激活函數不同。膠囊網絡使用squash激活函數,attention使用gelu和relu。

三、其他

1、squash激活函數

公式如下:

簡單吧!給大家看看圖長什麼樣。

畫面真是不能太舒服。。。

這個公式主要由兩個式子構成,左邊和右邊。。。其中s是指向量

左邊:向量二範式的平方 / 向量二範式的平方 + 1,說白了就是x^{2} / 1 + x^{2}

右邊:單位向量嘛

所以這個公式既保證了數據在0-1之間,也保留了向量的方向(可以理解爲數據在另一個維度的特徵)。

2、損失函數

Margin loss

啊,這裏不展開了

 

四、總結

原來這就是膠囊網絡。看過bert源碼的,對這個膠囊結構應該是一見如故吧,它倆可真是太像了,設計理念是異曲同工的,各類預測結果還都讓人眼前一亮。

對於膠囊網絡,個人覺得,對於數據的結構操作,不太適合做圖像的backbone。因爲膠囊輸入需要的是N, row, col三維數據,這就必然導致三維圖像(N, row, col, chanel)需要做reshape操作,這一操作,講不準損失了什麼圖像的空間信息。

總之,膠囊網絡是讀完了,不造給位看官是否還有點霧水上頭啊。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章