《DLOW:Domain Flow for Adaptation and Generalization》論文解析

今天說的這篇文章,也是用來解決遷移學習問題的。遷移學習要解決一個什麼問題呢?就是要把模型在source域(源域)學習到的知識,用到target域(目標域)裏。

DLOW這篇文章主要提出了兩點:1、可以把source域的數據遷移成中間域,中間域也就是介於source和target之間的域。  2、訓練的時候如果有多個target域的話,DLOW可以生成網絡沒有見過的數據風格。

那麼接下來介紹一下算法原理:

1、CycleGAN

作爲本算法的基礎,cycleGAN至關重要。具體的介紹請看我另一篇博客:

https://blog.csdn.net/wenqiwenqi123/article/details/105123491

在這裏複習一下,cycleGAN主要由兩個loss組成:

其中xs代表源域數據,xt代表目標域數據。Gst爲source到target的生成器,Dt爲判別目標域真實數據和生成數據的判別器。反之相同。

 

2、定義中間域

設中間域爲M(z),z爲[0,1]的一個變量,跟與source和target的聯繫有關。換句話說,z=0的時候,M(z)就是source,當z=1的時候,M(z)就是target。

如下圖所示,其實從S到T的路徑有許多,但是我們希望我們能找到一條最近的路(貼着地平線過去的那條,紅線)。

因此我們得到了如下公式:

其中dist爲某種距離表示,在本算法中使用了公式一的距離(cycleGAN)。

因此把3公式化簡一下,得到了loss函數:

 

3、DLOW模型

綜上,我們現在有了Source域的數據,和Z=[0,1],Gst的目的是得到中間域而不再是得到Target域。

因此有

Adversarial Loss:GAN肯定會有對抗損失。這裏定義判別器Ds(x)是區分M(z)和S,而Dt(x)是區分M(z)和T。因此對抗損失可以寫爲:

用上面的損失來代入dist的話,則得到:

Image Cycle Consistency Loss:cycleGAN的循環一致損失:

其中Gts是從target到M的生成器:

總損失:

 

實現:

整個網絡架構如下所示:

z和S一起作爲輸入輸進生成器Gst,對z作反捲積得到(1,16,1,1)的向量,同時對z進行採樣:

這樣的話,z會在一開始的時候趨向於較小,隨着訓練逐漸加大,這樣可以更穩定。

 

提升域自適應模型:

作者做了一個實驗,可以提升域自適應算法的能力。

原本的source域的數據爲S,那麼作者用DLOW這套算法,把z從[0,1]中均勻採樣,用生成器GST生成了新的數據集S~。

因此在S~中,數據分佈從S到T都有,再用S~數據集作爲域自適應算法的訓練數據,可以有效提升效果。

至於此處把Ladv賦予了一個權值,根號1-z,是因爲對於每一個樣本來說,如果z比較大的話,說明這個樣本更接近target域,因此對抗損失的權值需要降低。

 

風格生成:

大部分的風格遷移算法,都只能一對一地進行遷移。也就是說在訓練完後,就只能遷移到那個風格了。

但是DLOW可以生成訓練數據裏沒有見過的風格。假設有K個目標域,則z拓展成一個k維的向量[z1,z2,...,zk],所有z值加起來等於1。因此我們需要優化的目標變成了:

可以比較容易地修改網絡結構得到這個。

 

4、實驗部分

實驗部分作者做了倆實驗,一個是剛剛說的用生成的中間域數據訓練domain adaptation的模型,得到更好的結果。一個是風格遷移得到新風格unseen in the training data。

先說實驗一,做了一個語義分割的task,從GTA5遷移到Cityscapes。結果如下:

 

實驗二,作者用了真實照片遷移到油畫風格的task。用了莫奈、梵高等不同的target domain。

因此z變成了[z1,z2,z3,z4],其中z1+z2+z3+z4=1 

訓練的時候每五步分別用:[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1] ,均勻隨機取樣。得到了不錯的結果。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章