【使用keras框架下Resnet101_Unet深度學習模型對醫學核磁圖像語義分割】

盆骨分割模型測試

1說明

本次採用vnet2d二分類模型和Resnet101_Unet模型分別進行訓練測試。爲應對樣本分佈不均衡的問題,損失函數均採用二分類dice損失,不再使用之前最原始的二分類交叉熵損失。

二分類dice損失如下:

#二分類dice損失
def dice_coef(y_true, y_pred, smooth):
    """Return the smoothed Dice coefficient between ``y_true`` and ``y_pred``.

    Both tensors are reduced over axes 0, 1 and 2, which leaves one Dice
    value per remaining axis entry; those values are then averaged.
    NOTE(review): presumably the inputs are (batch, height, width, channels)
    tensors, so the result is a per-channel Dice averaged over channels --
    confirm against the data pipeline.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor, same shape as ``y_true``.
        smooth: Constant added to numerator and denominator; it keeps the
            ratio defined when both masks sum to zero.

    Returns:
        A scalar tensor: mean of (2 * overlap + smooth) / (total + smooth).
    """
    reduce_axes = (0, 1, 2)
    overlap = K.sum(y_true * y_pred, axis=reduce_axes)
    total = K.sum(y_true, axis=reduce_axes) + K.sum(y_pred, axis=reduce_axes)
    ratio = (2. * overlap + smooth) / (total + smooth)
    return K.mean(ratio)

def dice_loss(smooth):
    """Create a Dice loss function with the smoothing constant bound.

    Args:
        smooth: Smoothing constant forwarded to ``dice_coef``.

    Returns:
        A function ``dice(y_true, y_pred)`` that evaluates to
        ``1 - dice_coef(y_true, y_pred, smooth)``.
    """
    def dice(y_true, y_pred):
        # A higher Dice coefficient means better overlap, so subtract it
        # from 1 to obtain a quantity to minimise.
        score = dice_coef(y_true, y_pred, smooth)
        return 1 - score

    return dice
    

     Resnet101_Unet模型實現如下。值得注意的是,原始Unet網絡相對較淺,特徵提取能力有限,故這裏我們使用Resnet101對網絡模型進行加深,以提高特徵提取效果。


 # *******************resnet101 unet*********************

def conv3x3(x, out_filters, strides=(1, 1)):
    """Apply a bias-free 3x3 'same'-padded convolution to ``x``.

    Args:
        x: Input tensor.
        out_filters: Number of output filters.
        strides: Convolution strides; defaults to ``(1, 1)``.

    Returns:
        The convolved tensor (no normalisation or activation applied).
    """
    layer = Conv2D(
        out_filters,
        3,
        padding='same',
        strides=strides,
        use_bias=False,
        kernel_initializer='he_normal',
    )
    return layer(x)


def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', use_activation=True):
    """Apply convolution then batch normalisation, optionally followed by ReLU.

    Args:
        x: Input tensor.
        nb_filter: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution strides; defaults to ``(1, 1)``.
        padding: Convolution padding mode; defaults to ``'same'``.
        use_activation: When truthy, a ReLU is applied after the batch
            normalisation; any falsy value skips it.

    Returns:
        The resulting tensor.
    """
    out = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, kernel_initializer='he_normal')(x)
    # Normalise over axis 3 (the channel axis of the channels-last tensors
    # built in this file).
    out = BatchNormalization(axis=3)(out)
    if not use_activation:
        return out
    return Activation('relu')(out)


def basic_Block(input, out_filters, strides=(1, 1), with_conv_shortcut=False):
    """Build a two-convolution residual block (3x3 -> 3x3) on ``input``.

    Args:
        input: Input tensor.
        out_filters: Number of filters of both 3x3 convolutions.
        strides: Strides of the first 3x3 convolution and of the projection
            shortcut, when one is used.
        with_conv_shortcut: When true, the skip path is a 1x1 convolution
            plus batch normalisation; otherwise ``input`` is added as-is,
            which requires its shape to match the main path.

    Returns:
        ReLU(main_path + shortcut).
    """
    # Main path: conv-BN-ReLU followed by conv-BN (ReLU deferred until
    # after the residual addition).
    main = conv3x3(input, out_filters, strides)
    main = BatchNormalization(axis=3)(main)
    main = Activation('relu')(main)

    main = conv3x3(main, out_filters)
    main = BatchNormalization(axis=3)(main)

    # Skip path: identity by default, 1x1 projection on request.
    shortcut = input
    if with_conv_shortcut:
        shortcut = Conv2D(out_filters, 1, strides=strides, use_bias=False, kernel_initializer='he_normal')(input)
        shortcut = BatchNormalization(axis=3)(shortcut)

    merged = add([main, shortcut])
    return Activation('relu')(merged)


def bottleneck_Block(input, out_filters, strides=(1, 1), with_conv_shortcut=False):
    """Build a bottleneck residual block (1x1 -> 3x3 -> 1x1) on ``input``.

    The first two convolutions use ``out_filters / 4`` filters; the final
    1x1 convolution expands back to ``out_filters``.

    Args:
        input: Input tensor.
        out_filters: Number of filters produced by the block.
        strides: Strides of the 3x3 convolution and of the projection
            shortcut, when one is used.
        with_conv_shortcut: When true, the skip path is a 1x1 convolution
            plus batch normalisation; otherwise ``input`` is added as-is,
            which requires its shape to match the main path.

    Returns:
        ReLU(main_path + shortcut).
    """
    expansion = 4
    reduced_filters = int(out_filters / expansion)

    # 1x1 reduction.
    main = Conv2D(reduced_filters, 1, use_bias=False, kernel_initializer='he_normal')(input)
    main = BatchNormalization(axis=3)(main)
    main = Activation('relu')(main)

    # 3x3 convolution; this is where any spatial striding happens.
    main = Conv2D(reduced_filters, 3, strides=strides, padding='same', use_bias=False, kernel_initializer='he_normal')(main)
    main = BatchNormalization(axis=3)(main)
    main = Activation('relu')(main)

    # 1x1 expansion (ReLU deferred until after the residual addition).
    main = Conv2D(out_filters, 1, use_bias=False, kernel_initializer='he_normal')(main)
    main = BatchNormalization(axis=3)(main)

    # Skip path: identity by default, 1x1 projection on request.
    shortcut = input
    if with_conv_shortcut:
        shortcut = Conv2D(out_filters, 1, strides=strides, use_bias=False, kernel_initializer='he_normal')(input)
        shortcut = BatchNormalization(axis=3)(shortcut)

    merged = add([main, shortcut])
    return Activation('relu')(merged)


def unet_resnet_101(height=256, width=256, channel=1, classes=1):
    """Build a U-Net whose encoder follows the ResNet-101 bottleneck layout.

    The encoder is a 7x7 stem plus four bottleneck stages of 3, 4, 23 and 3
    blocks. The decoder upsamples five times, concatenating the encoder
    output of matching resolution at the first four steps, and ends in a
    1x1 convolution + batch normalisation followed by a sigmoid.

    The encoder halves the resolution five times with 'same' padding, so
    the skip concatenations only line up when ``height`` and ``width`` are
    multiples of 32.

    Args:
        height: Input image height.
        width: Input image width.
        channel: Number of input channels.
        classes: Number of output channels of the sigmoid layer.

    Returns:
        The assembled Keras ``Model`` mapping the input image to the
        sigmoid output named ``'Classification'``.
    """
    inputs = Input(shape=(height, width, channel))

    # Stem: 7x7 stride-2 convolution (1/2), then 3x3 stride-2 pooling (1/4).
    conv1_1 = Conv2D(64, 7, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(inputs)
    conv1_1 = BatchNormalization(axis=3)(conv1_1)
    conv1_1 = Activation('relu')(conv1_1)
    conv1_2 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(conv1_1)

    # conv2_x  1/4  (3 blocks)
    conv2 = bottleneck_Block(conv1_2, 256, strides=(1, 1), with_conv_shortcut=True)
    for _ in range(2):
        conv2 = bottleneck_Block(conv2, 256)

    # conv3_x  1/8  (4 blocks)
    conv3 = bottleneck_Block(conv2, 512, strides=(2, 2), with_conv_shortcut=True)
    for _ in range(3):
        conv3 = bottleneck_Block(conv3, 512)

    # conv4_x  1/16  (23 blocks)
    conv4 = bottleneck_Block(conv3, 1024, strides=(2, 2), with_conv_shortcut=True)
    for _ in range(22):
        conv4 = bottleneck_Block(conv4, 1024)

    # conv5_x  1/32  (3 blocks)
    conv5 = bottleneck_Block(conv4, 2048, strides=(2, 2), with_conv_shortcut=True)
    for _ in range(2):
        conv5 = bottleneck_Block(conv5, 2048)

    # Decoder: upsample x2, 2x2 conv, concatenate the encoder feature map of
    # the same resolution, then two 3x3 conv-BN-ReLU layers.
    up6 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv5), 1024, 2)
    merge6 = concatenate([conv4, up6], axis=3)
    conv6 = Conv2d_BN(merge6, 1024, 3)
    conv6 = Conv2d_BN(conv6, 1024, 3)

    up7 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv6), 512, 2)
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2d_BN(merge7, 512, 3)
    conv7 = Conv2d_BN(conv7, 512, 3)

    up8 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv7), 256, 2)
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2d_BN(merge8, 256, 3)
    conv8 = Conv2d_BN(conv8, 256, 3)

    # The 1/2-resolution skip comes from the stem output before pooling.
    up9 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv8), 64, 2)
    merge9 = concatenate([conv1_1, up9], axis=3)
    conv9 = Conv2d_BN(merge9, 64, 3)
    conv9 = Conv2d_BN(conv9, 64, 3)

    # Final upsampling back to the input resolution; no skip at this scale.
    up10 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv9), 64, 2)
    conv10 = Conv2d_BN(up10, 64, 3)
    conv10 = Conv2d_BN(conv10, 64, 3)

    # 1x1 projection to `classes` channels without ReLU, then sigmoid.
    conv11 = Conv2d_BN(conv10, classes, 1, use_activation=False)
    outputs = Activation('sigmoid', name='Classification')(conv11)

    model = Model(inputs=inputs, outputs=outputs)
    # The original built the model but never returned it, so callers got None.
    return model

核磁樣本勾畫33個共使用710個樣本。盆骨基本每個層面都有,樣本數量相對並不是很多,這裏暫不進行數據集拓展,不使用驗證集控制訓練。

Resnet101_Unet模型訓練參數如下表

 

batch_size

epochs

loss

accuracy

參數量(千萬)

Vnet

8

100

0.013

0.999

5.15

 

2訓練過程

2.1 Resnet101_Unet模型

訓練過程發現,Resnet101_Unet模型可能因爲待訓練參數數量很大(如下圖2所示,Resnet101_Unet模型參數量將近9千萬,比vnet2d網絡的5千萬參數多了將近一倍),導致loss下降非常緩慢:在我個人本地主機上訓練將近10小時,loss呈現下降趨勢,但要下降到目標期望範圍仍需花費較長時間。Resnet101_Unet模型同樣採用dice損失函數訓練100個epoch,batch_size爲8。

 

 

  

                          圖1 Unet模型訓練截止圖                                                            圖2  Unet模型參數量

 

    

  

                                       圖3 unet訓練準確率曲線                                                   圖4 unet訓練loss下降曲線

2.2 分割效果

測試選用12個未壓脂核磁數據集,共計263個切片作爲測試樣本。測試集整體預測相對良好

 

 

 

 

3 下一步計劃

本次訓練採用Dice損失來應對樣本不均衡現象,後續可採用generalized dice loss、tversky coefficient loss等新的損失,或者Dice+交叉熵複合損失函數作爲指導;另外,針對樣本集可採用改變對比度等方法適當擴充數據量,並將測試中分割較差的樣本加入訓練集。

訓練集不開源,代碼後續開源

項目合作qq:2642828613   

圖像處理深度學習交流羣:581148993

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章