Deep learning paper: LRNNet: a light-weighted network with efficient reduced non-local operation for real-time semantic segmentation, and its PyTorch implementation
LRNNet: a light-weighted network with efficient reduced non-local operation for real-time semantic segmentation
PDF: https://arxiv.org/pdf/2006.02706.pdf
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks
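1 Overview
This paper improves on LEDNet. It uses SVD to simplify the non-local network, and it builds a lightweight, efficient feature-extraction network from factorized convolution blocks (FCB), which handle long-range dependencies and short-range features in a suitable way.
LRNNet runs at 71 FPS on a GTX 1080Ti and reaches 72.2% mIoU, with only 0.68M parameters in the whole model.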
2 LRNNet
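Broadly, the LRNNet encoder is a three-stage ResNet-style network. Each stage begins with a downsampling unit that transitions the feature maps from one stage to the next. The core component of the encoder is the Factorized Convolution Block (FCB), which provides lightweight and efficient feature extraction. After the last downsampling unit, dilated convolutions are used so that the output feature map stays at 1/8 of the input resolution.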
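As a rough illustration of the resolution bookkeeping described above, the sketch below stacks three stride-2 convolutions and then applies a dilated 3x3 convolution. The layers and the channel widths (32, 64, 128) are hypothetical stand-ins chosen for the example, not the paper's actual downsampling unit; the point is only that three 2x reductions give 1/8 resolution and that a dilated convolution with matching padding does not reduce it further.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the three downsampling units (one per stage).
# Each stride-2 convolution halves the height and width.
downsample = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
)

# A 3x3 convolution with dilation d and padding d keeps the spatial size,
# so the receptive field grows without any further downsampling.
dilated = nn.Conv2d(128, 128, kernel_size=3, dilation=4, padding=4)

x = torch.randn(1, 3, 64, 128)
feat = downsample(x)   # spatial size 64x128 -> 8x16, i.e. 1/8 of the input
out = dilated(feat)    # spatial size unchanged

print(tuple(feat.shape), tuple(out.shape))
```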
2-1 Factorized Convolution Block
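A dilated kernel with a large dilation rate receives complex long-range spatial features and needs more parameters in the spatial dimensions. A dilated kernel with a small dilation rate receives simple or less informative short-range features, for which fewer parameters are enough. FCB (panel (c) of the figure above) therefore first splits the channels into two groups and processes the short-range, less informative spatial features in each group with two 1D convolutions, which greatly reduces parameters and computation. After the two groups are merged, FCB uses a 2D convolution to enlarge the receptive field and capture long-range features, implemented as a depthwise separable convolution to keep parameters and computation low. A channel shuffle operation is applied at the end.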
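To make the savings concrete, the following sketch counts parameters for the three kinds of convolution mentioned above, for a hypothetical width of 64 channels. The helper `n_params` and the variable names are illustrative only; the layer shapes follow the description in the text (1D convolutions on half the channels, then a dilated depthwise 3x3 followed by a 1x1 pointwise convolution).

```python
import torch.nn as nn


def n_params(module):
    """Total number of learnable parameters (weights and biases)."""
    return sum(p.numel() for p in module.parameters())


C = 64          # hypothetical channel width
half = C // 2   # each branch sees half of the channels after the split

# Baseline: a plain 3x3 convolution over all channels.
standard = nn.Conv2d(C, C, kernel_size=3, padding=1)

# One branch: a 3x1 convolution followed by a 1x3 convolution on half the channels.
factorized_branch = nn.Sequential(
    nn.Conv2d(half, half, kernel_size=(3, 1), padding=(1, 0)),
    nn.Conv2d(half, half, kernel_size=(1, 3), padding=(0, 1)),
)

# Long-range part: dilated depthwise 3x3 plus a 1x1 pointwise convolution.
depthwise_separable = nn.Sequential(
    nn.Conv2d(C, C, kernel_size=3, padding=2, dilation=2, groups=C),
    nn.Conv2d(C, C, kernel_size=1),
)

print("standard 3x3:          ", n_params(standard))               # 36928
print("two factorized branches:", 2 * n_params(factorized_branch))  # 12416
print("depthwise separable:    ", n_params(depthwise_separable))    # 4800
```

Note that the dilation rate does not change the parameter count of the depthwise convolution, so the long-range branch can grow its receptive field at no extra parameter cost.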
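import torch
import torch.nn as nn


class HalfSplit(nn.Module):
    """Split a tensor into two equal halves along `dim` (channels by default)."""
    def __init__(self, dim=1):
        super(HalfSplit, self).__init__()
        self.dim = dim

    def forward(self, input):
        splits = torch.chunk(input, 2, dim=self.dim)
        return splits[0], splits[1]
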
class ChannelShuffle(nn.Module):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
        N, C, H, W = x.size()
        g = self.groups
        # Integer division keeps the channel count exact; C must be divisible by g.
        return x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)

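def Conv1x1BN(in_channels, out_channels):
    # Conv1x1BN is not defined in this excerpt. It is assumed here, from its name,
    # to be a 1x1 convolution followed by batch normalization.
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1),
        nn.BatchNorm2d(out_channels),
    )


class FCB(nn.Module):
    def __init__(self, channels, dilation=1, groups=4):
        super(FCB, self).__init__()
        mid_channels = channels // 2
        self.half_split = HalfSplit(dim=1)
        # Branch 1: 3x1 then 1x3 convolution on the first half of the channels.
        self.first_bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1,
                      padding=(1, 0)),
            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1,
                      padding=(0, 1)),
        )
        # Branch 2: 1x3 then 3x1 convolution on the second half of the channels.
        self.second_bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(1, 3), stride=1,
                      padding=(0, 1)),
            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=(3, 1), stride=1,
                      padding=(1, 0)),
        )
        # Dilated depthwise 3x3 (groups=channels) followed by a 1x1 pointwise convolution.
        self.conv3x3 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1,
                                 dilation=dilation, padding=dilation, groups=channels)
        self.conv1x1 = Conv1x1BN(in_channels=channels, out_channels=channels)
        self.channelShuffle = ChannelShuffle(groups)

    def forward(self, x):
        x1, x2 = self.half_split(x)
        x1 = self.first_bottleneck(x1)
        x2 = self.second_bottleneck(x2)
        out = torch.cat([x1, x2], dim=1)
        out = self.conv1x1(self.conv3x3(out))
        # Residual connection, then shuffle so the channel groups mix.
        return self.channelShuffle(out + x)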
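The effect of the channel shuffle at the end of the block is easiest to see on a tiny tensor whose values are just the channel indices. The function `channel_shuffle` below is a standalone, hypothetical restatement of the same view/transpose/view trick used in the `ChannelShuffle` module above, written only for this demonstration.

```python
import torch


def channel_shuffle(x, groups):
    """Interleave channels across `groups` groups (same reshape trick as ChannelShuffle)."""
    n, c, h, w = x.shape
    return x.view(n, groups, c // groups, h, w).transpose(1, 2).contiguous().view(n, c, h, w)


# Eight channels, each holding its own index, so the new order can be read off directly.
x = torch.arange(8).view(1, 8, 1, 1)

order_g2 = channel_shuffle(x, groups=2).flatten().tolist()
order_g4 = channel_shuffle(x, groups=4).flatten().tolist()

print(order_g2)  # [0, 4, 1, 5, 2, 6, 3, 7]
print(order_g4)  # [0, 2, 4, 6, 1, 3, 5, 7]
```

In both cases, channels that started in the first half and channels that started in the second half end up interleaved, so when the next block splits the tensor in half, each branch receives features that came from both branches of the previous block.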
2-2 SVN module
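The SVN module uses SVD to simplify the non-local module, so that the overall model needs less computation, fewer parameters, and less memory.
1. Two 1x1 convolutions, Conv1 and Conv2, reduce the number of channels used in the non-local computation;
2. The key and value are replaced with spatial regional dominant singular vectors.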
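The two steps above can be sketched roughly as follows. This is a minimal illustration under several assumptions, not the authors' implementation: the class name `ReducedNonLocalSketch`, the `regions` parameter, the `restore` convolution, the residual connection, and the assignment of Conv1 to the query and Conv2 to the key/value are all hypothetical choices, and `torch.linalg.svd` is used here simply as a direct way to obtain each region's dominant singular vector.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ReducedNonLocalSketch(nn.Module):
    """Illustrative reduced non-local block (hypothetical, not the paper's code)."""

    def __init__(self, channels, reduced_channels, regions=2):
        super().__init__()
        self.regions = regions  # the map is split into regions x regions sub-regions
        self.conv1 = nn.Conv2d(channels, reduced_channels, kernel_size=1)  # assumed: query branch
        self.conv2 = nn.Conv2d(channels, reduced_channels, kernel_size=1)  # assumed: key/value branch
        self.restore = nn.Conv2d(reduced_channels, channels, kernel_size=1)  # assumed: back to input width

    def regional_vectors(self, feat):
        """Return one dominant left singular vector per spatial region: (N, S, C')."""
        n, c, h, w = feat.shape
        r = self.regions  # h and w must be divisible by r
        # (N, C', r, h/r, r, w/r) -> (N, r, r, C', h/r, w/r) -> (N, S, C', pixels per region)
        blocks = feat.view(n, c, r, h // r, r, w // r).permute(0, 2, 4, 1, 3, 5)
        blocks = blocks.reshape(n, r * r, c, -1)
        # Batched SVD; column 0 of U is the dominant left singular vector.
        # Singular vectors are only defined up to sign, which this sketch ignores.
        u, _, _ = torch.linalg.svd(blocks, full_matrices=False)
        return u[..., 0]

    def forward(self, x):
        n, c, h, w = x.shape
        q = self.conv1(x).flatten(2).transpose(1, 2)       # (N, HW, C')
        kv = self.regional_vectors(self.conv2(x))          # (N, S, C'), used as key and value
        attn = F.softmax(q @ kv.transpose(1, 2), dim=-1)   # (N, HW, S) instead of (N, HW, HW)
        out = (attn @ kv).transpose(1, 2).reshape(n, -1, h, w)
        return x + self.restore(out)


block = ReducedNonLocalSketch(channels=32, reduced_channels=8, regions=2)
x = torch.randn(2, 32, 16, 16)
y = block(x)
kv = block.regional_vectors(block.conv2(x))
print(tuple(y.shape), tuple(kv.shape))
```

With a 16x16 map and 2x2 regions, each position attends to 4 regional vectors rather than 256 positions, which is where the reduction in computation and memory comes from in this sketch.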