unet網絡魔改那些事

參考極市社區

在圖像分割任務特別是醫學圖像分割中,U-Net[1]無疑是最成功的方法之一,該方法在2015年MICCAI會議上提出,目前已達到四千多次引用。其採用的編碼器(下采樣)-解碼器(上採樣)結構和跳躍連接是一種非常經典的設計方法。目前已有許多新的卷積神經網絡設計方式,但很多仍延續了U-Net的核心思想,加入了新的模塊或者融入其他設計理念。本文對U-Net及其幾種改進版做一個介紹。

U-Net和3D U-Net

U-Net最初是一個用於二維圖像分割的卷積神經網絡,分別贏得了ISBI 2015細胞追蹤挑戰賽和齲齒檢測挑戰賽的冠軍[2]。U-Net的一個Karas實現代碼:

https://github.com/zhixuhao/unet

U-Net的結構如下圖所示,左側可視爲一個編碼器,右側可視爲一個解碼器。編碼器有四個子模塊,每個子模塊包含兩個卷積層,每個子模塊之後有一個通過max pool實現的下采樣層。輸入圖像的分辨率是572x572, 第1-5個模塊的分辨率分別是572x572, 284x284, 140x140, 68x68和32x32。由於卷積使用的是valid模式,故這裏後一個子模塊的分辨率等於(前一個子模塊的分辨率-4)/2。解碼器包含四個子模塊,分辨率通過上採樣操作依次上升,直到與輸入圖像的分辨率一致(由於卷積使用的是valid模式,實際輸出比輸入圖像小一些)。該網絡還使用了跳躍連接,將上採樣結果與編碼器中具有相同分辨率的子模塊的輸出進行連接,作爲解碼器中下一個子模塊的輸入。

file

3D U-Net[3]是U-Net的一個簡單擴展,應用於三維圖像分割,結構如下圖所示。相比於U-Net,該網絡僅用了三次下采樣操作,在每個卷積層後使用了batch normalization,但3D U-Net和U-Net均沒有使用dropout。

file

在2018年MICCAI腦腫瘤分割挑戰賽(brats)中[4],德國癌症研究中心的團隊使用3D U-Net,僅做了少量的改動,取得了該挑戰賽第二名的成績,發現相比於許多新的網絡,3D U-Net仍然十分具有優勢[5]。3D U-Net的一種Pytorch實現:

https://github.com/wolny/pytorch-3dunet

TernausNet

TernausNet全稱爲"TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation"[6]。該網絡將U-Net中的編碼器替換爲VGG11,並在ImageNet上進行預訓練,從735個參賽隊伍中脫穎而出,取得了Kaggle 二手車分割挑戰賽(Carvana Image Masking Challenge)第一名。代碼鏈接:

https://github.com/ternaus/TernausNet

下圖是該網絡的示意圖:

file

Res-UNet 和Dense U-Net

Res-UNet和Dense-UNet分別受到殘差連接和密集連接的啓發,將UNet的每一個子模塊分別替換爲具有殘差連接和密集連接的形式。[6] 中將Res-UNet用於視網膜圖像的分割,其結構如下圖所示,其中灰色實線表示各個模塊中添加的殘差連接。

file

密集連接即將子模塊中某一層的輸出分別作爲後續若干層的輸入的一部分,某一層的輸入則來自前面若干層的輸出的組合。下圖是[7]中的密集連接的一個例子。該文章中將U-Net的各個子模塊替換爲這樣的密集連接模塊,提出Fully Dense UNet 用於去除圖像中的僞影。

file

MultiResUNet

MultiResUNet[8]提出了一個MutiRes模塊與UNet結合。MutiRes模塊如下圖所示,是一個殘差連接的擴展,在該模塊中三個3x3的卷積結果拼接起來作爲一個組合的特徵圖,再與輸入特徵圖經過1x1卷積得到的結果相加。

file

該網絡的結構圖如下圖所示,其中各個MultiRes模塊的內部即爲上圖所示。

file

該網絡除了MultiRes模塊以外,還提出了一個殘差路徑(ResPath), 使編碼器的特徵在與解碼器中的對應特徵拼接之前,先進行了一些額外的卷積操作,如下圖所示。作者認爲編碼器中的特徵由於卷積層數較淺,是低層次的特徵,而解碼器中對應的特徵由於卷積層更深,是較高層次的特徵,二者在語義上有較大差距,推測不宜直接將二者進行拼接。因此,使用額外的ResPath使二者在拼接前具有一致的深度,在ResPath1, 2, 3, 4中分別使用4,3,2,1個卷積層。

file

該文章在ISIC、CVC-ClinicDB、Brats等多個數據集上驗證了其性能。代碼鏈接爲

https://github.com/nibtehaz/MultiResUNet

 

模型代碼keras;https://github.com/nibtehaz/MultiResUNet/blob/master/MultiResUNet.py

from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, BatchNormalization, Activation, add
from keras.models import Model, model_from_json
from keras.optimizers import Adam
from keras.layers.advanced_activations import ELU, LeakyReLU
from keras.utils.vis_utils import plot_model



def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), activation='relu', name=None):
    '''
    2D Convolutional layers
    
    Arguments:
        x {keras layer} -- input layer 
        filters {int} -- number of filters
        num_row {int} -- number of rows in filters
        num_col {int} -- number of columns in filters
    
    Keyword Arguments:
        padding {str} -- mode of padding (default: {'same'})
        strides {tuple} -- stride of convolution operation (default: {(1, 1)})
        activation {str} -- activation function (default: {'relu'})
        name {str} -- name of the layer (default: {None})
    
    Returns:
        [keras layer] -- [output layer]
    '''

    x = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, use_bias=False)(x)
    x = BatchNormalization(axis=3, scale=False)(x)

    if(activation == None):
        return x

    x = Activation(activation, name=name)(x)

    return x


def trans_conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(2, 2), name=None):
    '''
    2D Transposed Convolutional layers
    
    Arguments:
        x {keras layer} -- input layer 
        filters {int} -- number of filters
        num_row {int} -- number of rows in filters
        num_col {int} -- number of columns in filters
    
    Keyword Arguments:
        padding {str} -- mode of padding (default: {'same'})
        strides {tuple} -- stride of convolution operation (default: {(2, 2)})
        name {str} -- name of the layer (default: {None})
    
    Returns:
        [keras layer] -- [output layer]
    '''

    x = Conv2DTranspose(filters, (num_row, num_col), strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3, scale=False)(x)
    
    return x


def MultiResBlock(U, inp, alpha = 1.67):
    '''
    MultiRes Block
    
    Arguments:
        U {int} -- Number of filters in a corrsponding UNet stage
        inp {keras layer} -- input layer 
    
    Returns:
        [keras layer] -- [output layer]
    '''

    W = alpha * U

    shortcut = inp

    shortcut = conv2d_bn(shortcut, int(W*0.167) + int(W*0.333) +
                         int(W*0.5), 1, 1, activation=None, padding='same')

    conv3x3 = conv2d_bn(inp, int(W*0.167), 3, 3,
                        activation='relu', padding='same')

    conv5x5 = conv2d_bn(conv3x3, int(W*0.333), 3, 3,
                        activation='relu', padding='same')

    conv7x7 = conv2d_bn(conv5x5, int(W*0.5), 3, 3,
                        activation='relu', padding='same')

    out = concatenate([conv3x3, conv5x5, conv7x7], axis=3)
    out = BatchNormalization(axis=3)(out)

    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    return out


def ResPath(filters, length, inp):
    '''
    ResPath
    
    Arguments:
        filters {int} -- [description]
        length {int} -- length of ResPath
        inp {keras layer} -- input layer 
    
    Returns:
        [keras layer] -- [output layer]
    '''


    shortcut = inp
    shortcut = conv2d_bn(shortcut, filters, 1, 1,
                         activation=None, padding='same')

    out = conv2d_bn(inp, filters, 3, 3, activation='relu', padding='same')

    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    for i in range(length-1):

        shortcut = out
        shortcut = conv2d_bn(shortcut, filters, 1, 1,
                             activation=None, padding='same')

        out = conv2d_bn(out, filters, 3, 3, activation='relu', padding='same')

        out = add([shortcut, out])
        out = Activation('relu')(out)
        out = BatchNormalization(axis=3)(out)

    return out


def MultiResUnet(height, width, n_channels):
    '''
    MultiResUNet
    
    Arguments:
        height {int} -- height of image 
        width {int} -- width of image 
        n_channels {int} -- number of channels in image
    
    Returns:
        [keras model] -- MultiResUNet model
    '''


    inputs = Input((height, width, n_channels))

    mresblock1 = MultiResBlock(32, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(mresblock1)
    mresblock1 = ResPath(32, 4, mresblock1)

    mresblock2 = MultiResBlock(32*2, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(mresblock2)
    mresblock2 = ResPath(32*2, 3, mresblock2)

    mresblock3 = MultiResBlock(32*4, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(mresblock3)
    mresblock3 = ResPath(32*4, 2, mresblock3)

    mresblock4 = MultiResBlock(32*8, pool3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(mresblock4)
    mresblock4 = ResPath(32*8, 1, mresblock4)

    mresblock5 = MultiResBlock(32*16, pool4)

    up6 = concatenate([Conv2DTranspose(
        32*8, (2, 2), strides=(2, 2), padding='same')(mresblock5), mresblock4], axis=3)
    mresblock6 = MultiResBlock(32*8, up6)

    up7 = concatenate([Conv2DTranspose(
        32*4, (2, 2), strides=(2, 2), padding='same')(mresblock6), mresblock3], axis=3)
    mresblock7 = MultiResBlock(32*4, up7)

    up8 = concatenate([Conv2DTranspose(
        32*2, (2, 2), strides=(2, 2), padding='same')(mresblock7), mresblock2], axis=3)
    mresblock8 = MultiResBlock(32*2, up8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(
        2, 2), padding='same')(mresblock8), mresblock1], axis=3)
    mresblock9 = MultiResBlock(32, up9)

    conv10 = conv2d_bn(mresblock9, 1, 1, 1, activation='sigmoid')
    
    model = Model(inputs=[inputs], outputs=[conv10])

    return model
   

(這個網絡我測試了一下,採用dice損失,loss一直下降不去,網絡性能保持懷疑,有圖爲證如下)

R2U-Net

R2U-Net全稱叫做Recurrent Residual CNN-based U-Net[9]。該方法將殘差連接和循環卷積結合起來,用於替換U-Net中原來的子模塊,如下圖所示

file

其中環形箭頭表示循環連接。下圖表示了幾種不同的子模塊內部結構圖,(a)是常規的U-Net中使用的方法,(b)是在(a)的基礎上循環使用包含激活函數的卷積層,(c)是使用殘差連接的方式,(d)是該文章提出的結合(b)和(c)的循環殘差卷積模塊。

file

該方法也在皮膚病圖像、視網膜圖像、肺部圖像等幾個公共數據集驗證了其性能,代碼鏈接:

https://github.com/LeeJunHyun/Image\_Segmentation#r2u-net

Attention UNet

Attention UNet[10]在UNet中引入注意力機制,在對編碼器每個分辨率上的特徵與解碼器中對應特徵進行拼接之前,使用了一個注意力模塊,重新調整了編碼器的輸出特徵。該模塊生成一個門控信號,用來控制不同空間位置處特徵的重要性,如下圖中紅色圓圈所示。

file

該方法的注意力模塊內部如下圖所示,該模塊通過1x1x1的卷積分別與ReLU和Sigmoid結合,生成一個權重圖file, 通過與編碼器中的特徵相乘來對其進行校正。

file

下圖展示了注意力權重圖的可視化效果。從左至右分別是一幅圖像和隨着訓練次數的增加該圖像中得到的注意力權重。可見得到的注意力權重傾向於在目標器官區域取得大的值,在背景區域取得較小的值,有助於提高圖像分割的精度。

file

該文章的代碼鏈接:

https://github.com/ozan-oktay/Attention-Gated-Networks

其他

基於U-Net框架設計的圖像分割網絡還有很多,難以一一列舉,這裏再提供兩篇具有參考性的文章:

AnatomyNet: Deep 3D Squeeze-and-excitation U-Nets for fast and fully automated whole-volume anatomical segmentation

H-DenseUNet: Hybrid Densely Connected UNet for Liver and Liver Tumor Segmentation from CT Volumes

參考ziliao :

http://antkillerfarm.github.io/dl/2018/10/26/Deep_Learning_48.html

https://bbs.cvmart.net/topics/1422

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章