SGANPose | 自對抗人體姿態估計網絡

Self Adversarial Training for Human Pose Estimation
Official Code: pytorch

1.出發點

由於人體的遮擋和擁擠等現象,現有的人體姿態估計網絡很難解決此類情況下的準確估計,且此類現象會導致網絡估計的關鍵點不符合正常的人體姿態,失去了人體固有的形態。比如下圖中第二行圖片所示,相較於第一行,很顯然有部分關節是違背事實的。作者希望即使在擁擠狀態下,網絡預測得到關鍵點也能夠符合關節所固有結構。基於此作者提出使用生成對抗的方式來解決這個問題。

2.自對抗網絡結構

與傳統的GAN模型類似,本文的模型分爲兩個網絡,生成器和鑑別器。第一個網絡生成器是一個卷積網絡,生成器經過前向計算,得到一組熱圖,它指示每個關鍵點的每個位置的置信度得分。第二個網絡鑑別器,具有與生成器相同的架構,但它將熱圖與RGB圖像一起編碼輸入,並將其解碼爲一組新的熱圖,以便區分真實的熱圖和虛假的熱圖。本文提出的自對抗網絡結果如下圖所示。在最終做關鍵點前向推理時,會將鑑別器從整體的結果中剔除。

3.生成器

生成器主要的作用是生成準確的人體關鍵點信息。當然作爲生成對抗中的一環,生成器最主要的功能就是能夠讓生成的關鍵點欺騙最終的鑑別器,使得鑑別器無法區分當前關鍵點熱圖是GT還是生成器生成的。因此,如下圖所示,訓練生成器時,其通過兩部分進行優化,分別爲反向傳播來自生成器的損耗Lmse和來自鑑別器的對抗損耗Ladv。

整體的loss如下所示,公式1的損失Lmse目的是使得生成器最終生成的人體關鍵點能夠更加接近標籤。公式2的對抗性損失Ladv,該對抗損失的目的是使得生成器最終生成的關鍵點符合更加合理的姿態。更直接的說,Ladv的目的是使得生成器生成的虛假熱度圖能夠儘可能的糊弄鑑別器,使其無法區分GT熱圖和虛假熱圖。生成對抗的過程就體現在這裏。最終利用公式3所示的損失來優化生成器。其中lamda是一個超參數。

4.鑑別器

鑑別器的目標是區分輸入進來的熱圖是GT還是生成器生成的虛假熱圖。鑑別器最終的訓練目標就是能夠把生成器生成的數據竟可能和GT區分出來。從而和生成器形成一個對抗博弈的過程。因此,如下圖所示,訓練鑑別器時,其通過兩部分進行優化,分別爲反向傳播來自鑑別器的損耗Lreal和來自鑑別器的損耗Lfake。

整體的loss如下所示,公式(4.1)表示將GT熱圖輸入鑑別器得到編碼後的新熱圖,並計算新熱圖和GT熱圖的距離,進行Lreal損失計算。公式(4.2)表示將生成器生成的虛假熱圖輸入鑑別器得到編碼後的新熱圖,並計算新熱圖和生成器生成的虛假熱圖之間的距離進行Lfake損失計算。正如前述提到過的,鑑別器的目的是儘可能的將虛假熱圖和GT熱圖區分開來,也就是說鑑別器希望GT熱圖輸入後的輸出重構熱圖儘可能和GT接近,希望虛假熱圖輸入後的輸出重構熱圖儘可能和虛假熱圖不同。從loss上來說就是希望Lreal越來越小,希望Lfake越來越大。基於此,鑑別器的loss如公式(4.3)所示。

上述公式中的kt是用來約束鑑別器的能力,通過公式(5)約束kt能夠使得網絡更容易訓練。正如許多論文中提到的那樣,GAN不穩定且難以訓練,因爲鑑別器過快收斂,導致網絡很容易崩潰,訓練出無效的生成器。鑑別器過快收斂,從loss來分析就是:Lfake小於Lreal,生成器生成的熱圖足夠真實以欺騙鑑別器。 此時,kt將增加,以使Lfake更具優勢,從而使得鑑別器進行更多的訓練才能識別生成的熱圖。它在Lfake上加速訓練的比例取決於鑑別器落在與生成器的差距。當Lfake大於Lreal時原理類似。

對公式4進行解讀:
公式4.1 輸入爲原始RGB圖像X,GT熱度圖C。計算的Lreal表示鑑別器產生的結果和GT熱度圖之間的差別。
公式4.2輸入爲原始RGB圖像X,生成器產生的熱度圖C^。計算Lfake表示鑑別器產生的結果和生成器產生熱度圖之間的差別。
公式4.3表示最終整個公式4,也就是鑑別器的loss的目的是最小化Lreal和Lfake,即整個優化過程要求Lreal小且Lfake大,直白的來說就是要求當輸入爲GT熱度圖時,鑑別器產生儘可能和GT相同的結果。當輸入爲生成器產生的熱度圖時,鑑別器產生儘可能和生成器不同的結果。如,如果右膝蓋的信心在左膝蓋附近很高,則訓練有素的鑑別器將產生右膝蓋的熱圖,該熱圖在左膝蓋的位置具有較大的誤差。由於鑑別器就像評論家一樣, 它在輸入熱圖上提供了詳細的“註釋”,並建議熱圖中的哪些部分未產生真實姿勢。最終整個誤差會在公式2中體現出來。而公式二會指導生成器進行進化,使得最終的生成器更好,降低整個誤差。

##### 5.算法整體流程

整體算法每一個迭代過程如下:

1.將GT熱度圖C,原始圖像X輸入到鑑別器,計算鑑別器的前向結果。爲D(X,C)。同時計算鑑別器的loss,公式4.1,Lreal。
2.將原始圖像X輸入到生成器,計算生成器的前向結果C^。同時計算生成器loss,公式1,Lmse。
3.將虛假熱度圖C,原始圖像X輸入到鑑別器,計算鑑別器的前向結果。爲D(X,C)。同時計算鑑別器的loss,公式4.2,Lfake。(累計Lreal和Lfake梯度值,並更新鑑別器參數,公式4.3)。
4.有了虛假熱度圖C和D(X,C),利用公式2計算對抗loss,Ladv,並更新生成器。

##### 6.結果展示
作者在LSP和MPII兩個人體關鍵點數據集上對上述自對抗網絡進行了結果分析,從下表可以看出,利用對抗生成的方式能夠有效提升模型效果,且不會增加推理時間。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章