SAGAN:Self-Attention Generative Adversarial Networks論文解讀

在上一篇博客中,SEnet關注於各通道間的依賴關係,通過增強和減弱某個通道特徵的表達使得網絡有“注意力”。而SAGAN將注意力機制應用在GAN生成項目中,主要是每個feature map 和它自身的轉置相乘,讓任意兩個位置的像素直接發生關係,這樣就可以學習到任意兩個像素之間的依賴關係,從而得到全局特徵。從單個特徵出發,使得網絡有“注意力”能力。

 

1、GAN生成圖片時存在的問題

      ① 在卷積神經網絡中,每個卷積核的尺寸都是很有限的(基本上不會大於5),因此每次卷積操作只能覆蓋像素點周圍很小一塊鄰域。對於距離較遠的特徵,例如狗有四條腿這類特徵,就不容易捕獲到了(也不是完全捕獲不到,因爲多層的卷積、池化操作會把 feature map 的高和寬變得越來越小,越靠後的層,其卷積核覆蓋的區域映射回原圖對應的面積越大。但總而言之,畢竟還得需要經過多層映射,不夠直接)。Self-Attention 通過直接計算圖像中任意兩個像素點之間的關係,一步到位地獲取圖像的全局幾何特徵

     ② GAN在生成圖片時,擅長合成幾乎沒有結構約束的圖像類別,例如海洋、天空和景觀類別。但是無法捕獲在某些類別中始終如一地出現的幾何或結構模式,如狗有四條腿、狗的皮毛等。這種原因並不是GAN自身帶來的,而是以前的GAN模型在很大程度上依賴於卷積來模擬不同圖像區域之間的依賴關係。卷積運算是基於一個局部感受域,只能在經過幾個卷積層之後才能處理圖像中遠距離的相關性。這種學習長期相關性問題上,在卷積網絡下可能無法表示它們, 因爲在模型優化階段捕獲多個層相關性參數是不容易的,並且這些參數化可能是統計學上的,這就帶來了一定的問題。增加捲積核的大小可以增加網絡的表示能力, 但這樣做也會失去使用本地卷積結構獲得的計算和統計效率

 2、SAGAN的亮點

    ① 將Attention機制用在GAN生成上。在此之前,self-Attention在Attention is All You Need中提出,應用在機器翻譯上, 文中指出引入自我注意機制可以很好的學習到序列的依賴關係,從全局上去分析序列Non-local Neureal Networks 一文將Self-Attention應用在了視頻分類上,推導說明Self-Attention對於圖像前後依賴關係上是有很大意義的,同時在單幅圖像中可以很好地學習到圖像的全局特徵

     那麼可以利用self Attention機制更好地學習全局特徵之間的依賴關係。傳統的 GAN 模型很容易學習到紋理特徵:如皮毛,天空,草地等,不容易學習到特定的結構和幾何特徵,例如狗有四條腿,既不能多也不能少。

     自我注意(Self-Attention)引入到卷積GAN中是對卷積的補充,有助於模擬跨越圖像區域的長距離,多級別依賴關係。在Self-Attention作用下發生器在生成圖像過程中, 可以對每個位置的精細細節都與圖像遠處的精細細節進行仔細協調。

    ② 用到了譜歸一化,提出於Spectral Normalization for GANs。SAGAN代碼中的譜歸一化和原始的譜歸一化運用方式略有差別:1、原始的譜歸一化基於 W-GAN 的理論,只用在 Discriminator 中,用以約束 Discriminator 函數爲 1-Lipschitz 連續。而在 Self-Attention GAN 中,Spectral Normalization 同時出現在了 Discriminator 和 Generator 中,用於使梯度更穩定。除了生成器和判別器的最後一層外,每個 卷積/反捲積 單元都會上一個 SpectralNorm。2、當把譜歸一化用在 Generator 上時,同時還保留了 BatchNorm。Discriminator 上則沒有 BatchNorm,只有 SpectralNorm。3、譜歸一化用在 Discriminator 上時最後一層不加 Spectral Norm。

3、SAGAN原理及結構

     SAGAN模型的整體框架和GAN是一樣的,僅僅是由一個生成器和一個判別器組成,只是在生成器和判別器網絡設計的內部加入了Self-Attention層。

    主要的結構圖如下,

 

    實際上就是 feature map 和它自身的轉置相乘,讓任意兩個位置的像素直接發生關係,這樣就可以學習到任意兩個像素之間的依賴關係,從而得到全局特徵。

Self-Attention

class Self_Attn(nn.Module):
    """ Self attention Layer"""
    def __init__(self,in_dim,activation):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        
        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1) #
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B * (W*H) * C
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B * C * (W*H)
        energy =  torch.bmm(proj_query,proj_key) # transpose check,batch matrix  
 multiplication [B*N*N],N = W*H
        attention = self.softmax(energy) # B * (N) * (N) 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B * C * N

        out = torch.bmm(proj_value,attention.permute(0,2,1) )  # [B,C,N] * [B,N,N] = [B,C,N]
        out = out.view(m_batchsize,C,width,height) # [B,C,W,H]
        
        out = self.gamma*out + x
        return out,attention

其中,proj_query爲[B,N,C], proj_key爲[B,C,N],其中N = W*H。前者將二維feature map每個channel拉成一個長度爲N = W*H的向量。矩陣每行代表一個像素位置上所有通道的值,每列代表某個通道中所有的像素值。後者沒有轉置操作,得到C*N的向量。矩陣每一行代表一個通道中所有的像素值,每一列代表一個像素位置上所有通道的值。 

 

energy =  torch.bmm(proj_query,proj_key) # batch matrix multiply

torch.bmm計算tensor矩陣乘法,將相同batchsize的兩組matrix對應地做矩陣乘法,最終得到相同batchsize的新矩陣。一個應用實例爲

>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(batch1, batch2)
>>> res.size()
torch.Size([10, 3, 5])

已知 proj_query維度是 [B,N,C] ,proj_key的維度是[B,C,N],因此energy的維度是[B,N,N],N = W*H,示意圖爲(引自參考1)

      

energy是Attention注意力的核心,其中energy[i][j]是由轉置後proj_query[N*C]的第i行和proj_key[C*N]的第j列通過向量點乘得到的。proj_query[N*C]的第i行代表feature map上第i個像素位置上所有通道的值(即第i個像素位置的所有信息),proj_key[C*N]的第j列代表feature map上第j個像素位置上的所有通道值(即第j個像素位置的所有信息)。二者相乘可以看作第j個像素對第i個像素的影響。即energy[i][j]的元素值表示feature map第j個像素點對第i個像素點的影響。

 

接下來是將energy進行按“行”歸一化,softmax輸出是一個概率分佈: 每個元素都是非負的, 並且所有元素的總和都是1。經過softmax處理,energy各行元素的和均爲1。因爲 energy 中第 i 行元素,代表 feature map 中所有位置的像素對第 i 個像素的影響,而這個影響被解釋爲權重,故加起來應該是 1,故應對其按行歸一化。attention變量的維度爲[B,N,N]。

attention = self.softmax(energy) # 其中nn.softmax(x,dim),dim:指明維度,dim=0表示按列計算;dim=1表示按行計算

緊接着對原feature map做1*1卷積映射,並將其維度由[B,C,W,H] reshape爲 [B,C,N],其中N = W*H爲feature map的像素數,得到proj_value。其中C爲通道數,C = in_dim。此處的C與上面計算得到的proj_query和proj_key的維度C = in_dim/8不同。proj_value每行代表每個通道上所有位置的像素值,每列則代表每個像素位置上所有通道的像素值。

proj_value = self.value_conv(x).view(m_batchsize,-1,width*height)

                                         

接下來,將attention矩陣轉置,並將proj_value與其相乘。attention矩陣轉置的原因爲,attention經softmax按行歸一化後,每一行都代表其他所有位置的元素對當前第i行(第i個位置)元素的影響,每一行元素的值的和爲1,意義爲不同的權重(影響程度),轉置後每一列的和爲1。施加於proj_value的行(每個通道上所有位置的像素值)上,作爲該行的加權平均。

proj_value[i]代表第i個通道所有的像素值,attention[j]代表所有像素值施加到第j個像素的影響。二者相乘後得到out,out的第i行包括了輸出的第i個通道中的所有像素,第j列表示所有像素中的第j個像素。則out[i][j]表示被 attention 加權之後的 feature map 的第 i 個通道的第 j 個像素的像素值。再改變一下形狀,out恢復爲[B,C,W,H]的結構。

out = torch.bmm(proj_value,attention.permute(0,2,1) )  # [B,C,N] * [B,N,N] = [B,C,N]
out = out.view(m_batchsize,C,width,height) # [B,C,W,H]

 

                                

       最後將輸出的out做殘差處理。

out = self.gamma*out + x # x爲輸入的feature map

             借鑑resnet中的操作,gamma爲可調整的參數(由nn.parameters()定義),表示整體施加了attention之後的feature map的權重,需要通過反向傳播更新。初始階段,gamma爲0,attention模塊直接返回輸入的feature map。之後隨着學習,attention模塊逐漸學習到了將attention加權過的feature map加在原始的feature map上,從而強調了需要施加註意力的部分 feature map。

4、模型訓練技巧

文中主要突出兩點說明:

1.譜歸一化的使用,譜歸一化限制了每個曾的譜範數來約束判別器的Lipschitz常數,詳細請參見spectral normalization一文。

2.不平衡學習率的設定,對於G和D的訓練穩定性一直是GAN訓練需要考慮的,不平衡的學習率往往可以是訓練更加的穩定。

參考:

1、https://zhuanlan.zhihu.com/p/55741364

2、http://www.twistedwg.com/2018/06/21/SAGAN.html

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章