幾種視覺Attention的代碼詳解

幾種視覺Attention的代碼詳解


最近看了幾篇很優秀的視覺Attention介紹的文章,詳細見參考鏈接。這裏再拾人牙慧,將代碼再清晰整理一遍,並自己編寫了Self_Attn_Channel 注意力。

1 SENet - 通道注意力

在這裏插入圖片描述

在這裏插入圖片描述

#SENet:Squeeze-and-Excitation Networks  
#通道注意力
#論文地址:https://arxiv.org/abs/1709.01507
#代碼地址:https://github.com/hujie-frank/SENet

class SELayer(nn.Module):
    '''
    func: 實現通道Attention. 
    parameters:
        channel: input的通道數, input.size = (batch,channel,w,h) if batch_first else (channel,batch,,w,h)
        reduction: 默認4. 即在FC的時,存在channel --> channel//reduction --> channel的轉換
        batch_first: 默認True.如input爲channel_first,則batch_first = False
    '''
    def __init__(self, channel,reduction = 2, batch_first = True):
        super(SELayer, self).__init__()
        
        self.batch_first = batch_first
        self.avg_pool = nn.AdaptiveAvgPool2d(1) 
        self.fc = nn.Sequential(
            nn.Linear(channel,channel // reduction, bias = False),
            nn.ReLU(inplace = True),
            nn.Linear(channel // reduction, channel, bias = False),
            nn.Sigmoid()
            )
        
    def forward(self, x):
        '''
        input.size == output.size 
        '''
        if not self.batch_first:
            x = x.permute(1,0,2,3)  
            
        b, c, _, _ = x.size() 
        y = self.avg_pool(x).view(b,c) #size = (batch,channel)
                
        y = self.fc(y).view(b,c,1,1)  #size = (batch,channel,1,1)
        out = x * y.expand_as(x) #size = (batch,channel,w,h)
        
        if not self.batch_first: 
            out = out.permute(1,0,2,3) #size = (channel,batch,w,h)

        return out 
    
    
x = torch.randn(size = (4,8,20,20))        
selayer = SELayer(channel = 8, reduction = 2, batch_first = True)
out = selayer(x)    
print(out.size()) 

'''
output: 
torch.Size([4, 8, 20, 20])
'''   

2 CBAM - 通道 + 空間注意力

在這裏插入圖片描述

在這裏插入圖片描述

在這裏插入圖片描述

#CBAM:Convolutional Block Attention Module(CBAM)

class ChannelAttention(nn.Module):
    '''
    func: 實現通道Attention. 
    parameters:
        in_channels: input的通道數, input.size = (batch,channel,w,h) if batch_first else (channel,batch,,w,h)
        reduction: 默認4. 即在FC的時,存在in_channels --> in_channels//reduction --> in_channels的轉換
        batch_first: 默認True.如input爲channel_first,則batch_first = False
    '''
    def __init__(self,in_channels, reduction = 4, batch_first = True):
        
        super(ChannelAttention,self).__init__()
        
        self.batch_first = batch_first
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, kernel_size = 1, bias = False),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels // reduction, in_channels, kernel_size = 1, bias = False),
            )
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        if not self.batch_first: 
            x = x.permute(1,0,2,3) 
        
        avgout = self.sharedMLP(self.avg_pool(x)) #size = (batch,in_channels,1,1)
        maxout = self.sharedMLP(self.max_pool(x)) #size = (batch,in_channels,1,1)
        
        w = self.sigmoid(avgout + maxout) #通道權重  size = (batch,in_channels,1,1)
        out = x * w.expand_as(x) #返回通道注意力後的值 size = (batch,in_channels,w,h)
        
        if not self.batch_first:
            out = out.permute(1,0,2,3) #size = (channel,batch,w,h)

        return out
    
class SpatialAttention(nn.Module):
    '''
    func: 實現空間Attention. 
    parameters:
        kernel_size: 卷積核大小, 可選3,5,7,
        batch_first: 默認True.如input爲channel_first,則batch_first = False
    
    '''
    def __init__(self, kernel_size = 3, batch_first = True):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3,5,7), "kernel size must be 3 or 7"
        padding = kernel_size // 2
        
        self.batch_first = batch_first
        self.conv = nn.Conv2d(2,1,kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        
        if not self.batch_first:
            x = x.permute(1,0,2,3)  #size = (batch,channels,w,h)
        
        avgout = torch.mean(x, dim=1, keepdim=True) #size = (batch,1,w,h)
        maxout,_ = torch.max(x, dim=1, keepdim=True)  #size = (batch,1,w,h)
        x1 = torch.cat([avgout, maxout], dim=1)    #size = (batch,2,w,h)
        x1 = self.conv(x1)    #size = (batch,1,w,h)
        w = self.sigmoid(x1)   #size = (batch,1,w,h)  
        out = x * w            #size = (batch,channels,w,h)

        if not self.batch_first:
            out = out.permute(1,0,2,3) #size = (channels,batch,w,h)

        return  out
    

class CBAtten_Res(nn.Module):
    '''
    func:channel attention + spatial attention + resnet
    parameters:
        in_channels: input的通道數, input.size = (batch,in_channels,w,h) if batch_first else (in_channels,batch,,w,h);
        out_channels: 輸出的通道數
        kernel_size: 默認3, 可選[3,5,7]
        stride: 默認2, 即改變out.size --> (batch,out_channels,w/stride, h/stride).
                一般情況下,out_channels = in_channels * stride
        reduction: 默認4. 即在通道atten的FC的時,存在in_channels --> in_channels//reduction --> in_channels的轉換
        batch_first:默認True.如input爲channel_first,則batch_first = False
    
    '''
    def __init__(self,in_channels,out_channels,kernel_size = 3, 
                 stride = 2, reduction = 4,batch_first = True):
        
        super(CBAtten_Res,self).__init__()
        
        self.batch_first = batch_first
        self.reduction = reduction
        self.padding = kernel_size // 2
        
        
        #h/2, w/2
        self.max_pool = nn.MaxPool2d(3, stride = stride, padding = self.padding)
        self.conv_res = nn.Conv2d(in_channels, out_channels,
                               kernel_size = 1,
                               stride = 1,
                               bias = True)
        
        
        #h/2, w/2
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size = kernel_size,
                               stride = stride, 
                               padding = self.padding,
                               bias = True)
        self.bn1 = nn.BatchNorm2d(out_channels) 
        self.relu = nn.ReLU(inplace = True)
        self.ca = ChannelAttention(out_channels, reduction = self.reduction,
                                   batch_first = self.batch_first)
        
        self.sa = SpatialAttention(kernel_size = kernel_size,
                                   batch_first = self.batch_first)
        
        
    def forward(self,x):
        
        if not self.batch_first:
            x = x.permute(1,0,2,3)  #size = (batch,in_channels,w,h)
        residual = x 
        
        out = self.conv1(x)   #size = (batch,out_channels,w/stride,h/stride)
        out = self.bn1(out) 
        out = self.relu(out) 
        out = self.ca(out)
        out = self.sa(out)  #size = (batch,out_channels,w/stride,h/stride)
        
        residual = self.max_pool(residual)  #size = (batch,in_channels,w/stride,h/stride)
        residual = self.conv_res(residual)  #size = (batch,out_channels,w/stride,h/stride)
        
        out += residual #殘差
        out = self.relu(out)  #size = (batch,out_channels,w/stride,h/stride)
        
        if not self.batch_first:
            out = out.permute(1,0,2,3) #size = (out_channels,batch,w/stride,h/stride) 
            
        return out
    
    
x = torch.randn(size = (4,8,20,20))  
cba = CBAtten_Res(8,16,reduction = 2,stride = 1) 
y = cba(x)
print('y.size:',y.size())   

'''
y.size: torch.Size([4, 16, 20, 20])
'''

3 SKEConv

在這裏插入圖片描述
在這裏插入圖片描述

#SKENet: Selective Kernel Networks
# 論文地址:https://arxiv.org/abs/1903.06586
# 代碼地址:https://github.com/implus/SKNet

class SKEConv(nn.Module):
    '''
    func: 實現Selective Kernel Networks(SKE) Attention機制。主要由Spit + Fuse + Select 三個模塊組成 
    parameters:
        in_channels: input的通道數;
        M: Split階段. 使用不同大小的卷積核(M個)對input進行卷積,得到M個分支,默認2;
        G: 在卷積過程中使用分組卷積,分組個數爲G, 默認爲2.可以減小參數量;
        stride: 默認1. split卷積過程中的stride,也可以選2,降低輸入輸出的w,h;
        L: 默認32; 
        reduction: 默認2,壓縮因子; 在線性部分壓縮部分,輸出特徵d = max(L, in_channels / reduction);
        batch_first: 默認True;
        
    '''
    def __init__(self,in_channels, M = 2, G = 2, stride = 1, L = 32, reduction = 2, batch_first = True):
        
        super(SKEConv,self).__init__()
        
        self.M = 2
        self.in_channels = in_channels
        self.batch_first = batch_first
        self.convs = nn.ModuleList([])
        for i in range(M):
            self.convs.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, in_channels, 
                              kernel_size = 3 + i*2,
                              stride = stride,
                              padding = 1 + i,
                              groups = G),
                    nn.BatchNorm2d(in_channels),
                    nn.ReLU(inplace = True)
                    ))
        
        self.d = max(int(in_channels / reduction), L)
        self.fc = nn.Linear(in_channels, self.d)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(nn.Linear(self.d,in_channels))
            
        self.softmax = nn.Softmax(dim = 1)
        
        
    def forward(self, x):
        
        if not self.batch_first:
            x = x.permutation(1,0,2,3)
            
        for i ,conv in enumerate(self.convs):
            fea = conv(x).unsqueeze_(dim = 1)  #size = (batch,1,in_channels,w,h)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas,fea],dim = 1) #size = (batch,M,in_channels,w,h)
        
        fea_U = torch.sum(feas,dim = 1) #size = (batch,in_channels,w,h)
        fea_s = fea_U.mean(-1).mean(-1) #size = (batch,in_channels)
        fea_z = self.fc(fea_s)  #size = (batch,d)
        
        for i,fc in enumerate(self.fcs):
            vector = fc(fea_z).unsqueeze_(dim=1) #size = (batch,1,in_channels)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors,vector],
                                              dim = 1)  #size = (batch,M,in_channels)
                
        attention_vectors = self.softmax(attention_vectors) #size = (batch,M,in_channels)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1) #size = (batch,M,in_channels,w,h) 
        fea_v = (feas * attention_vectors).sum(dim=1) #size = (batch,in_channels,w,h)
        
        if not self.batch_first:
            fea_v = fea_v.permute(1,0,2,3)
                    
        return fea_v
    
#%%
x = torch.randn(size = (4,8,20,20))  
ske = SKEConv(8,stride = 2)
y = ske(x)
print('y.size:',y.size())   

'''
y.size: torch.Size([4, 16, 10, 10])
'''

4 self-attention

在這裏插入圖片描述

4.1 Self_Attn_Spatial 空間注意力

#視覺應用中的self-attention機制

class Self_Attn_Spatial(nn.Module):
    """ 
    func: Self attention Spatial Layer 自注意力機制.通過類似Transformer中的Q K V來實現
    inputs:
        in_dim: 輸入的通道數
        out_dim: 在進行self attention時生成Q,K矩陣的列數, 一般默認爲in_dim//8
    """
    def __init__(self,in_dim,out_dim):
        super(Self_Attn_Spatial,self).__init__()
        self.chanel_in = in_dim
        self.out_dim = out_dim
 
        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = out_dim , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = out_dim , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))
 
        self.softmax  = nn.Softmax(dim=-1)
        
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        
        #proj_query中的第i行表示第i個像素位置上所有通道的值。size = B X N × C1
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) 
        
        #proj_key中的第j行表示第j個像素位置上所有通道的值,size = B X C1 x N
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) 
        
        #Energy中的第(i,j)是將proj_query中的第i行與proj_key中的第j行點乘得到
        #energy中第(i,j)位置的元素是指輸入特徵圖第j個元素對第i個元素的影響,
        #從而實現全局上下文任意兩個元素的依賴關係
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        
        #對行的歸一化,對於(i,j)位置即可理解爲第j位置對i位置的權重,所有的j對i位置的權重之和爲1
        attention = self.softmax(energy) # B X N X N
        
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N
        out = torch.bmm(proj_value,attention.permute(0,2,1)) #B X C X N
        out = out.view(m_batchsize,C,width,height) #B X C X W X H
        
        #跨連,Gamma是需要學習的參數
        out = self.gamma*out + x #B X C X W X H
        
        return out,attention

x = torch.randn(size = (4,16,20,20))  
self_atten_spatial = Self_Attn_Spatial(16,4)
y = self_atten_spatial(x)
print('y.size:',y[0].size())   

'''
y.size: torch.Size([4, 16, 20, 20])
'''

4.2 Self_Attn_Channel 通道注意力

  • 注意:目前的non_local 和 self_attention基本都是空間注意力,沒有實現通道注意力。
  • 這裏作者根據自己對Transformer注意力的理解,給出了Self_Attn_Channel,即通道注意力。
class Self_Attn_Channel(nn.Module):
    """ 
    func: Self attention Channel Layer 自注意力機制.通過類似Transformer中的Q K V來實現
    inputs:
        in_dim: 輸入的通道數
        out_dim: 在進行self attention時生成Q,K矩陣的列數, 默認可選取爲:in_dim
        
    """
    def __init__(self,in_dim,out_dim ):
        super(Self_Attn_Channel,self).__init__()
        self.chanel_in = in_dim
        self.out_dim = out_dim
 
        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = out_dim , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = out_dim , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = out_dim , kernel_size= 1)
        self.x_conv = nn.Conv2d(in_channels = in_dim , out_channels = out_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))
 
        self.softmax  = nn.Softmax(dim=-1)
        
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C0 X W X H)
            returns :
                out : self attention value + input feature
                attention: B X C1 X C1 (N is Width*Height)
        """
        #C0 = in_dim; C1 = out_dim
        
        m_batchsize,C0,width ,height = x.size() 
        
        #proj_query中的第i行表示第i個通道位置上所有像素的值: size = B X C1 × N
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height) 
        
        #proj_key中的第j行表示第j個通道位置上所有像素的值,size = B X N x C1
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) 
        
        #Energy中的第(i,j)是將proj_query中的第i行與proj_key中的第j行點乘得到
        #energy中第(i,j)位置的元素是指輸入特徵圖第j個通道對第i個通道的影響,
        #從而實現全局上下文任意兩個通道的依賴關係. size = B X C1 X C1
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        
        #對於(i,j)位置即可理解爲第j通道對i通道的權重,所有的j對i通道的權重之和爲1
        #對行進行歸一化,即每行的所有列加起來爲1
        attention = self.softmax(energy) # B X C1 X C1
        
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C1 X N
        out = torch.bmm(attention, proj_value) #B X C1 X N
        out = out.view(m_batchsize,self.out_dim, width,height) #B X C1 X W X H
        
        #跨連,Gamma是需要學習的參數
        out = self.gamma*out + self.x_conv(x) #B X C1 X W X H
        
        return out,attention

x = torch.randn(size = (4,8,20,20))  
self_atten_channel = Self_Attn_Channel(8, 8)
y = self_atten_channel(x)
print('y.size:',y[0].size()) 

'''
output:
y.size: torch.Size([4, 8, 20, 20])
''' 

5 Non-local

在這裏插入圖片描述

import torch
from torch import nn
from torch.nn import functional as F


class NonLocalBlockND(nn.Module):
    """
    func: 非局部信息統計的注意力機制
    inputs: 
        in_channels:輸入的通道數,輸入是batch_first = True。
        inter_channels: 生成attention時Conv的輸出通道數,一般爲in_channels//2.
                        如果爲None, 則自動爲in_channels//2
        dimension: 默認2.可選爲[1,2,3],
                  1:輸入爲size = [batch,in_channels, width]或者[batch,time_steps,seq_length],可表示時序數據
                  2: 輸入size = [batch, in_channels, width,height], 即圖片數據
                  3: 輸入size = [batch, time_steps, in_channels, width,height],即視頻數據
                    
        sub_sample: 默認True,是否在Attention過程中對input進行size降低,即w,h = w//2, h//2               
        bn_layer: 默認True
    
    """
    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 dimension=2,
                 sub_sample=True,
                 bn_layer=True):
        super(NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            # 進行壓縮得到channel個數
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels,
                         out_channels=self.inter_channels,
                         kernel_size=1,
                         stride=1,
                         padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels,
                        out_channels=self.in_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0), bn(self.in_channels))
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels,
                             out_channels=self.in_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels,
                             out_channels=self.inter_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
        self.phi = conv_nd(in_channels=self.in_channels,
                           out_channels=self.inter_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)
            
            
    def forward(self, x):
        
        
        #if dimension == 3 , N = w*h*t ; if sub_sample: N1 = (w//2) * (h//2) * t ,else: N1 = N
        #if dimension == 2 , N = w*h  
        #if dimension == 1 , N = w 
        #C0 = in_channels;   C1 = inter_channels

            
        batch_size = x.size(0) 

        g_x = self.g(x).view(batch_size, self.inter_channels, -1) #[B, C1, N1]
        g_x = g_x.permute(0, 2, 1) #[B, N1, C1]

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) #[B, C1, N]
        theta_x = theta_x.permute(0, 2, 1) #[B, N, C1]

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) #[B, C1, N1]
        
        f = torch.matmul(theta_x, phi_x) #[B,N,N1]

        # print(f.shape) 

        f_div_C = F.softmax(f, dim=-1) 

        y = torch.matmul(f_div_C, g_x) #[B,N,N1] *[B, N1, C1] = [B,N,C1] 
        y = y.permute(0, 2, 1).contiguous() #[B,C1,N] 

        size = [batch_size, self.inter_channels] + list(x.size()[2:])
        y = y.view(size)  #size = [B,N,C1,x.size()[2:]] 
        
        W_y = self.W(y)  #1 × 1 卷積 size = x.size()
        z = W_y + x  #殘差連接
        return z 

x = torch.randn(size = (4,16,20,20))  
non_local = NonLocalBlockND(16,inter_channels = 8,dimension = 2)
y = non_local(x)
print('y.size:',y.size())

'''
output:
y.size: torch.Size([4, 16, 20, 20])
'''

6 參考鏈接

注意力機制在分類網絡中的應用:SENet、SKNet、CBAM

來聊聊 ResNet 及其變種

Self-attention機制及其應用:Non-local網絡模塊

Attention綜述:基礎原理、變種和最近研究

一文看懂 Attention(本質原理+3大優點+5大類型)

模型彙總24 - 深度學習中Attention Mechanism詳細介紹:原理、分類及應用

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章