論文地址:CCNet: Criss-Cross Attention for Semantic Segmentation
代碼地址:CCNet github
一、簡介
CCNet是2018年11月發佈的一篇語義分割方面的文章中提到的網絡,該網絡有三個優勢:
- GPU內存友好;
- 計算高效;
- 性能好。
CCNet之前的論文比如FCNs只能管制局部特徵和少部分的上下文信息,空洞卷積只能夠集中於當前像素而無法生成密集的上下文信息,雖然PSANet能夠生成密集的像素級的上下文信息但是計算效率過低,其計算複雜度高達O((H*W)*(H*\W))。因此可以明顯的看出,CCNet的目的是高效的生成密集的像素級的上下文信息。
Cirss-Cross Attention Block的參數對比如下圖所示:
CCNet論文的主要貢獻:
- 提出了Cirss-Cross Attention Module;
- 提出了高效利用Cirss-Cross Attention Module的CCNet。
二、結構
1、CCNet結構
CCNet的網絡結構如下圖所示:
CCNet的基本結構描述如下:
- 1、圖像通過特徵提取網絡得到feature map的大小爲,爲了更高效的獲取密集的特徵圖,將原來的特徵提取網絡中的後面兩個下采樣去除,替換爲空洞卷積,使得feature map的大小爲輸入圖像的1/8;
- 2、feature map X分爲兩個分支,分別進入3和4;
- 3、一個分支先將X進行通道縮減壓縮特徵,然後通過兩個CCA(Cirss-Cross Attention)模塊,兩個模塊共享相同的參數,得到特徵;
- 4、另一個分支保持不變爲X;
- 5、將3和4兩個分支的特徵融合到一起最終經過upsample得到分割圖像。
2、Criss-Cross Attention
Criss-Cross Attention模塊的結構如下所示,輸入feature爲,分爲三個分支,都通過1*1的卷積網絡的進行降維得到()。其中Attention Map 是和通過Affinity操作計算的。Affinity操作定義爲:
其中是在特徵圖Q的空間維度上的u位置的值。是上位置處的同列和同行的元素的集合。因此,是中的第個元素,其中。而表示和之間的聯繫的權重,。最後對進行在通道維度上繼續進行softmax操作計算Attention Map 。
另一個分支經過一個1*1卷積層得到的適應性特徵。同樣定義和,是上u點的同行同列的集合,則定義Aggregation操作爲:
該操作在保留原有feature的同時使用經過attention處理過的feature來保全feature的語義性質。
3、Recurrent Criss-Cross Attention
單個Criss-Cross Attention模塊能夠提取更好的上下文信息,但是下圖所示,根據criss-cross attention模塊的計算方式左邊右上角藍色的點只能夠計算到和其同列同行的關聯關係,也就是說相應的語義信息的傳播無法到達左下角的點,因此再添加一個Criss-Cross Attention模塊可以將該語義信息傳遞到之前無法傳遞到的點。
採用Recurrent Criss-Cross Attention之後,先定義loop=2,第一個loop的attention map爲,第二個loop的attention map爲,從原feature上位置到權重的映射函數爲,feature 中的位置用表示,feature中用表示,如果和相同則:
其中表示加到操作,如果和不同則:
Cirss-Cross Attention模塊可以應用於多種任務不僅僅是語義分割,作者同樣在多種任務中使用了該模塊,可以參考論文。
4、代碼
下面是Cirss-Cross Attention模塊的代碼可以看到ca_weight便是Affinity操作,ca_map便是Aggregation操作。
class CrissCrossAttention(nn.Module):
""" Criss-Cross Attention Module"""
def __init__(self,in_dim):
super(CrissCrossAttention,self).__init__()
self.chanel_in = in_dim
self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self,x):
proj_query = self.query_conv(x)
proj_key = self.key_conv(x)
proj_value = self.value_conv(x)
energy = ca_weight(proj_query, proj_key)
attention = F.softmax(energy, 1)
out = ca_map(attention, proj_value)
out = self.gamma*out + x
return out
Affinity操作定義如下:
class CA_Weight(autograd.Function):
@staticmethod
def forward(ctx, t, f):
# Save context
n, c, h, w = t.size()
size = (n, h+w-1, h, w)
weight = torch.zeros(size, dtype=t.dtype, layout=t.layout, device=t.device)
_ext.ca_forward_cuda(t, f, weight)
# Output
ctx.save_for_backward(t, f)
return weight
@staticmethod
@once_differentiable
def backward(ctx, dw):
t, f = ctx.saved_tensors
dt = torch.zeros_like(t)
df = torch.zeros_like(f)
_ext.ca_backward_cuda(dw.contiguous(), t, f, dt, df)
_check_contiguous(dt, df)
return dt, df
Aggregation操作定義如下:
class CA_Map(autograd.Function):
@staticmethod
def forward(ctx, weight, g):
# Save context
out = torch.zeros_like(g)
_ext.ca_map_forward_cuda(weight, g, out)
# Output
ctx.save_for_backward(weight, g)
return out
@staticmethod
@once_differentiable
def backward(ctx, dout):
weight, g = ctx.saved_tensors
dw = torch.zeros_like(weight)
dg = torch.zeros_like(g)
_ext.ca_map_backward_cuda(dout.contiguous(), weight, g, dw, dg)
_check_contiguous(dw, dg)
return dw, dg
其中使用ext是c庫文件:
RCC模塊的實現如下所示:
class RCCAModule(nn.Module):
def __init__(self, in_channels, out_channels, num_classes):
super(RCCAModule, self).__init__()
inter_channels = in_channels // 4
self.conva = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
InPlaceABNSync(inter_channels))
self.cca = CrissCrossAttention(inter_channels)
self.convb = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
InPlaceABNSync(inter_channels))
self.bottleneck = nn.Sequential(
nn.Conv2d(in_channels+inter_channels, out_channels, kernel_size=3, padding=1, dilation=1, bias=False),
InPlaceABNSync(out_channels),
nn.Dropout2d(0.1),
nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
)
def forward(self, x, recurrence=1):
output = self.conva(x)
for i in range(recurrence):
output = self.cca(output)
output = self.convb(output)
output = self.bottleneck(torch.cat([x, output], 1))
return output
CCNet的整體結構:
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes):
self.inplanes = 128
super(ResNet, self).__init__()
self.conv1 = conv3x3(3, 64, stride=2)
self.bn1 = BatchNorm2d(64)
self.relu1 = nn.ReLU(inplace=False)
self.conv2 = conv3x3(64, 64)
self.bn2 = BatchNorm2d(64)
self.relu2 = nn.ReLU(inplace=False)
self.conv3 = conv3x3(64, 128)
self.bn3 = BatchNorm2d(128)
self.relu3 = nn.ReLU(inplace=False)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.relu = nn.ReLU(inplace=False)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1,1,1))
#self.layer5 = PSPModule(2048, 512)
self.head = RCCAModule(2048, 512, num_classes)
self.dsn = nn.Sequential(
nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
InPlaceABNSync(512),
nn.Dropout2d(0.1),
nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
)
def forward(self, x, recurrence=1):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x_dsn = self.dsn(x)
x = self.layer4(x)
x = self.head(x, recurrence)
return [x, x_dsn]
三、結果
與主流的方法的比較:
下面是不同loop時的效果可以看到loop=2時的效果要比loop=2好。下面是不同loop的attention map。