Pelvis Segmentation Model
Preface
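U-Net and FCN are very similar. U-Net was proposed slightly later than FCN, but both were published in 2015. Compared with FCN, the first distinguishing feature of U-Net is that it is fully symmetric: the left and right halves are very much alike, whereas FCN's decoder is comparatively simple, using only a single deconvolution operation with no convolutional layers after it. The second difference is the skip connection: FCN uses summation, while U-Net uses concatenation. These are details, though. The key point is that both architectures follow a classic idea, the encoder-decoder, which Hinton had already proposed back in 2006 in a paper published in Science.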
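For this hugely popular architecture, we first extract its topology, which makes it easier to analyse its essence without being distracted by details. The input is an image and the output is the segmentation of the target. Simplifying further: an image is encoded (downsampled), then decoded (upsampled), and a segmentation result is produced. The difference between this result and the ground-truth segmentation is back-propagated to train the network. The most interesting parts of U-Net are therefore these three:
- Downsampling
- Upsampling
- skip connection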
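The three building blocks above can be sketched on a tiny single-channel feature map. This is only an illustration in plain Python, not the Keras layers used later in this post: the function names are hypothetical, pooling is assumed to be 2x2 max pooling, and upsampling is assumed to be nearest-neighbour. It also contrasts the summation skip (FCN) with the concatenation skip (U-Net) mentioned in the preface.

```python
def max_pool_2x2(fmap):
    """Downsample: 2x2 max pooling with stride 2 (height and width must be even)."""
    return [
        [max(fmap[r][c], fmap[r][c + 1], fmap[r + 1][c], fmap[r + 1][c + 1])
         for c in range(0, len(fmap[0]), 2)]
        for r in range(0, len(fmap), 2)
    ]


def upsample_2x(fmap):
    """Upsample: nearest-neighbour, every value becomes a 2x2 patch."""
    out = []
    for row in fmap:
        wide = [v for v in row for _ in range(2)]
        out.append(wide)
        out.append(list(wide))
    return out


def skip_sum(encoder_map, decoder_map):
    """FCN-style skip: element-wise sum, the channel count stays the same."""
    return [[e + d for e, d in zip(e_row, d_row)]
            for e_row, d_row in zip(encoder_map, decoder_map)]


def skip_concat(encoder_map, decoder_map):
    """U-Net-style skip: stack along the channel axis, the channel count adds up."""
    return [encoder_map, decoder_map]


encoder_map = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
pooled = max_pool_2x2(encoder_map)      # 4x4 -> 2x2
restored = upsample_2x(pooled)          # 2x2 -> 4x4
summed = skip_sum(encoder_map, restored)
stacked = skip_concat(encoder_map, restored)

print("pooled:", pooled)
print("channels after sum:", 1, "| channels after concat:", len(stacked))
```

Concatenation keeps the encoder features untouched and lets the following convolutions decide how to mix them, at the cost of more input channels for those convolutions.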
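Take U-Net as an example: the architecture in the original paper downsamples the input four times and upsamples it four times to produce the segmentation. Why four times? Simply because the authors chose it, or rather because four downsampling steps worked well on the dataset they were using. How deep a U-Net needs to be is actually very flexible. A related point is the feature extractor: small innovations on the encoder keep appearing, and the most direct approach is to plug in star architectures from ImageNet, such as BottleNeck, VGG16 and Residual from a few years back and DenseNet from last year, in a race to publish first. Papers of this kind are steps from 1 to 10, whereas proposing the underlying U-Net structure was the step from 0 to 1. If the feature extractor is a dense block, the network is called DenseUNet; if a residual block works better, it is called ResUNet. For more, see this article: https://zhuanlan.zhihu.com/p/44958351
How to define a residual network in keras: https://blog.csdn.net/Tourior/article/details/83824436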
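Depth is flexible, but it is constrained by the input size: with stride-2 downsampling and 2x upsampling, the decoder feature maps only line up with the encoder skip connections if the size halves cleanly at every level. The sketch below traces the sizes for a chosen depth. The function name is hypothetical, and it assumes 'same' padding (output size is the ceiling of half the input size).

```python
def trace_unet_sizes(size, depth):
    """Return the encoder sizes per level and whether the decoder sizes match the skips."""
    down = [size]
    for _ in range(depth):
        down.append((down[-1] + 1) // 2)  # stride-2 step with 'same' padding: ceil(n / 2)

    current = down[-1]
    for skip_size in reversed(down[:-1]):
        current *= 2                      # 2x upsampling in the decoder
        if current != skip_size:
            return down, False            # concatenation with the skip would fail here
    return down, True


sizes_4, ok_4 = trace_unet_sizes(320, 4)
sizes_7, ok_7 = trace_unet_sizes(320, 7)
print(sizes_4, ok_4)
print(sizes_7, ok_7)
```

For a 320-pixel input, four levels align, while seven levels break at the 5-to-3 step, so a deeper network would need padding, cropping, or a different input size.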
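1 Overview
This work trains and tests two models separately: a vnet2d binary segmentation model and a Resnet101_Unet model. To counter class imbalance, both use a binary Dice loss instead of the plain binary cross-entropy loss used previously.
The binary Dice loss is as follows: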
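The formula itself did not survive in this post, so what follows is the commonly used soft Dice form, loss = 1 - (2 * sum(p * g) + smooth) / (sum(p) + sum(g) + smooth), and not necessarily the exact variant used for training. The function name is hypothetical, and the default `smooth=1e-5` is borrowed from the commented-out `dice_coef_loss_fun(smooth=1e-5)` call in the model code. The example also shows why Dice suits imbalanced masks: a prediction that misses a tiny foreground still scores high pixel accuracy but a Dice loss close to 1.

```python
def dice_loss(pred, target, smooth=1e-5):
    """Soft Dice loss on flattened masks; pred holds probabilities, target holds 0/1 labels."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)


# 100 pixels, only one of them foreground (strong class imbalance).
target = [1.0] + [0.0] * 99
all_background = [0.0] * 100

accuracy = sum(1 for p, t in zip(all_background, target) if p == t) / len(target)
loss_missed = dice_loss(all_background, target)
loss_perfect = dice_loss(target, target)

print("pixel accuracy when the foreground is missed:", accuracy)
print("Dice loss when the foreground is missed:", loss_missed)
print("Dice loss for a perfect prediction:", loss_perfect)
```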
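The Resnet101_Unet model is implemented as follows. Note that the original Unet is relatively shallow and has limited feature extraction capacity, so here Resnet101 is used to deepen the network and improve feature extraction.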
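# *******************resnet101 unet*********************
# Imports assume tf.keras; adjust the module paths if you use standalone Keras.
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, Activation,
                                     MaxPooling2D, UpSampling2D, add, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Nadam


def conv3x3(x, out_filters, strides=(1, 1)):
    """3x3 convolution with 'same' padding and no bias (a BN layer follows it)."""
    x = Conv2D(out_filters, 3, padding='same', strides=strides, use_bias=False, kernel_initializer='he_normal')(x)
    return x

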
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', use_activation=True):
    """Convolution followed by batch normalization and an optional ReLU."""
    x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, kernel_initializer='he_normal')(x)
    x = BatchNormalization(axis=3)(x)
    if use_activation:
        x = Activation('relu')(x)
    return x


def basic_Block(inputs, out_filters, strides=(1, 1), with_conv_shortcut=False):
    """Basic residual block: two 3x3 convolutions plus a shortcut."""
    x = conv3x3(inputs, out_filters, strides)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)
    x = conv3x3(x, out_filters)
    x = BatchNormalization(axis=3)(x)
    if with_conv_shortcut:
        # Project the input with a 1x1 convolution so its shape matches x.
        residual = Conv2D(out_filters, 1, strides=strides, use_bias=False, kernel_initializer='he_normal')(inputs)
        residual = BatchNormalization(axis=3)(residual)
        x = add([x, residual])
    else:
        # Identity shortcut: input and output shapes already match.
        x = add([x, inputs])
    x = Activation('relu')(x)
    return x


def bottleneck_Block(inputs, out_filters, strides=(1, 1), with_conv_shortcut=False):
    """Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand, plus a shortcut."""
    expansion = 4
    de_filters = out_filters // expansion  # reduced channel count inside the block
    x = Conv2D(de_filters, 1, use_bias=False, kernel_initializer='he_normal')(inputs)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)
    x = Conv2D(de_filters, 3, strides=strides, padding='same', use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)
    x = Conv2D(out_filters, 1, use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization(axis=3)(x)
    if with_conv_shortcut:
        # Project the input with a 1x1 convolution so its shape matches x.
        residual = Conv2D(out_filters, 1, strides=strides, use_bias=False, kernel_initializer='he_normal')(inputs)
        residual = BatchNormalization(axis=3)(residual)
        x = add([x, residual])
    else:
        # Identity shortcut: input and output shapes already match.
        x = add([x, inputs])
    x = Activation('relu')(x)
    return x


def unet_resnet_101(height=320, width=320, channel=1, classes=3):
    """U-Net with a ResNet-101 encoder; returns a compiled Keras model."""
    inputs = Input(shape=(height, width, channel))

    # conv1 1/2, then max pooling to 1/4
    conv1_1 = Conv2D(64, 7, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(inputs)
    conv1_1 = BatchNormalization(axis=3)(conv1_1)
    conv1_1 = Activation('relu')(conv1_1)
    conv1_2 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(conv1_1)

    # conv2_x 1/4
    conv2_1 = bottleneck_Block(conv1_2, 256, strides=(1, 1), with_conv_shortcut=True)
    conv2_2 = bottleneck_Block(conv2_1, 256)
    conv2_3 = bottleneck_Block(conv2_2, 256)

    # conv3_x 1/8
    conv3_1 = bottleneck_Block(conv2_3, 512, strides=(2, 2), with_conv_shortcut=True)
    conv3_2 = bottleneck_Block(conv3_1, 512)
    conv3_3 = bottleneck_Block(conv3_2, 512)
    conv3_4 = bottleneck_Block(conv3_3, 512)

    # conv4_x 1/16: one strided block followed by 22 identity blocks (23 in total)
    conv4_23 = bottleneck_Block(conv3_4, 1024, strides=(2, 2), with_conv_shortcut=True)
    for _ in range(22):
        conv4_23 = bottleneck_Block(conv4_23, 1024)

    # conv5_x 1/32
    conv5_1 = bottleneck_Block(conv4_23, 2048, strides=(2, 2), with_conv_shortcut=True)
    conv5_2 = bottleneck_Block(conv5_1, 2048)
    conv5_3 = bottleneck_Block(conv5_2, 2048)

    # Decoder: upsample, concatenate with the matching encoder stage, then two 3x3 convs.
    up6 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv5_3), 1024, 2)
    merge6 = concatenate([conv4_23, up6], axis=3)
    conv6 = Conv2d_BN(merge6, 1024, 3)
    conv6 = Conv2d_BN(conv6, 1024, 3)

    up7 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv6), 512, 2)
    merge7 = concatenate([conv3_4, up7], axis=3)
    conv7 = Conv2d_BN(merge7, 512, 3)
    conv7 = Conv2d_BN(conv7, 512, 3)

    up8 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv7), 256, 2)
    merge8 = concatenate([conv2_3, up8], axis=3)
    conv8 = Conv2d_BN(merge8, 256, 3)
    conv8 = Conv2d_BN(conv8, 256, 3)

    up9 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv8), 64, 2)
    merge9 = concatenate([conv1_1, up9], axis=3)
    conv9 = Conv2d_BN(merge9, 64, 3)
    conv9 = Conv2d_BN(conv9, 64, 3)

    # Back to full resolution; there is no encoder feature map to merge at this scale.
    up10 = Conv2d_BN(UpSampling2D(size=(2, 2))(conv9), 64, 2)
    conv10 = Conv2d_BN(up10, 64, 3)
    conv10 = Conv2d_BN(conv10, 64, 3)

    # 1x1 convolution to per-pixel class scores, then softmax.
    conv11 = Conv2d_BN(conv10, classes, 1, use_activation=False)
    activation = Activation('softmax', name='Classification')(conv11)
    # conv_out = Conv2D(classes, 1, activation='softmax', padding='same', kernel_initializer='he_normal')(conv11)
    model = Model(inputs=inputs, outputs=activation)

    # Alternative losses and compile settings tried by the author:
    # model_dice = dice_p_bce
    # model_dice = compounded_loss(smooth=0.0005, gamma=2., alpha=0.25)
    # model_dice = tversky_coef_loss_fun(alpha=0.3, beta=0.7)
    # model_dice = dice_coef_loss_fun(smooth=1e-5)
    # model.compile(optimizer=Nadam(learning_rate=2e-4), loss=model_dice, metrics=['accuracy'])
    # Without metrics:
    # model_dice = focal_loss(alpha=.25, gamma=2)
    # model.compile(optimizer=Adam(learning_rate=2e-5), loss=dice_coef, metrics=['accuracy'])
    # model.compile(optimizer=Nadam(learning_rate=2e-4), loss="categorical_crossentropy", metrics=['accuracy'])

    # NOTE: focal_lossm is not defined in this post; it has to be provided elsewhere.
    # The text states that a Dice loss was used for training.
    # learning_rate replaces the deprecated lr argument of the optimizer.
    model.compile(optimizer=Nadam(learning_rate=2e-5), loss=focal_lossm, metrics=['accuracy'])
    return model
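The stage layout in the code above (3, 4, 23 and 3 bottleneck blocks) is what gives ResNet-101 its name. As a small check, the sketch below counts the weighted layers in the usual way: the stem convolution, the three convolutions on the main path of each bottleneck block, and the final fully connected layer of the classification network. Shortcut projections are not counted by this convention. The names are illustrative only, and the encoder here drops the fully connected layer because the decoder takes its place.

```python
# Number of bottleneck blocks per stage, as built in unet_resnet_101.
BLOCKS_PER_STAGE = {"conv2_x": 3, "conv3_x": 4, "conv4_x": 23, "conv5_x": 3}
CONVS_PER_BOTTLENECK = 3  # 1x1 reduce, 3x3, 1x1 expand on the main path


def count_weighted_layers(blocks_per_stage, include_fc=True):
    """Count the stem conv, the main-path convs of every block, and optionally the FC layer."""
    stem = 1  # the 7x7 convolution at the input
    body = CONVS_PER_BOTTLENECK * sum(blocks_per_stage.values())
    return stem + body + (1 if include_fc else 0)


total_blocks = sum(BLOCKS_PER_STAGE.values())
resnet101_layers = count_weighted_layers(BLOCKS_PER_STAGE)
encoder_layers = count_weighted_layers(BLOCKS_PER_STAGE, include_fc=False)

print("bottleneck blocks:", total_blocks)
print("layers in the classification network:", resnet101_layers)
print("layers kept in the U-Net encoder:", encoder_layers)
```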
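Annotations were drawn on 33 MRI cases, giving 710 samples in total. The pelvis appears in almost every slice. The number of samples is relatively small; for now the dataset is not augmented, and no validation set is used to control training.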
The training parameters of the Resnet101_Unet model are listed in the table below.
| Model | batch_size | epochs | loss | accuracy | Parameters (tens of millions) |
|---|---|---|---|---|---|
| Vnet | 8 | 100 | 0.013 | 0.999 | 5.15 |
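2 Training process
2.1 Resnet101_Unet model
During training, the loss of the Resnet101_Unet model decreased very slowly, possibly because of its large number of trainable parameters. As Figure 1 below shows, Resnet101_Unet has nearly 90 million parameters, almost twice the roughly 50 million of the vnet2d network. After nearly 10 hours of training on my local machine, the loss was trending downward but would still need considerable time to reach the target range. The Resnet101_Unet model was likewise trained with the Dice loss for 100 epochs with a batch_size of 8.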
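Figure 1: Unet model at the point training was stopped. Figure 2: Unet model parameter count.
Figure 3: unet training accuracy curve. Figure 4: unet training loss curve.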
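To get a feel for where the parameters come from, the convolution weights of the encoder can be counted by hand. The sketch below is a rough estimate under stated assumptions: it counts convolution kernels only (no biases, no batch normalization parameters, no decoder), assumes the single-channel 7x7 stem and the 3/4/23/3 stage layout from the code above, and uses hypothetical function names.

```python
def bottleneck_conv_weights(in_ch, out_ch, with_conv_shortcut=False):
    """Convolution weights in one bottleneck block (kernels only)."""
    mid = out_ch // 4                      # reduced width inside the block
    weights = in_ch * mid                  # 1x1 reduce
    weights += 3 * 3 * mid * mid           # 3x3
    weights += mid * out_ch                # 1x1 expand
    if with_conv_shortcut:
        weights += in_ch * out_ch          # 1x1 projection on the shortcut
    return weights


def encoder_conv_weights(input_channels=1):
    """Convolution weights in the stem plus the four bottleneck stages."""
    total = 7 * 7 * input_channels * 64    # 7x7 stem convolution
    in_ch = 64
    for out_ch, blocks in [(256, 3), (512, 4), (1024, 23), (2048, 3)]:
        # First block of each stage changes the channel count and needs a projection.
        total += bottleneck_conv_weights(in_ch, out_ch, with_conv_shortcut=True)
        total += (blocks - 1) * bottleneck_conv_weights(out_ch, out_ch)
        in_ch = out_ch
    return total


encoder_total = encoder_conv_weights(input_channels=1)
print("encoder convolution weights:", encoder_total)
```

Under these assumptions the encoder alone accounts for roughly 42 million weights, with the 23-block conv4_x stage contributing the largest share; the rest of the model's parameters sit in the decoder convolutions.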
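2.2 Segmentation results
Testing used 12 non-fat-suppressed MRI datasets, 263 slices in total, as test samples. Overall, predictions on the test set were fairly good.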
3 Next steps
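1. Loss function improvements:
This round of training used the Dice loss to handle class imbalance. Later work could use newer losses such as generalized dice loss or tversky coefficient loss, or a compound Dice + cross-entropy loss. In addition, the sample set could be moderately expanded with methods such as contrast changes, and test samples that segment poorly could be added to the training set.
2. Feature extraction network improvements:
Resnet101 is currently used as the backbone; the plan is to test a Densenet network next.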
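As an illustration of the Tversky loss mentioned in item 1, here is a minimal plain-Python sketch. It follows the common convention in which `alpha` weights false positives and `beta` weights false negatives; whether the author's `tversky_coef_loss_fun` uses the same convention is an assumption. The defaults `alpha=0.3, beta=0.7` are taken from the commented-out call in the model code, and the function name is hypothetical.

```python
def tversky_loss(pred, target, alpha=0.3, beta=0.7, smooth=1e-5):
    """Tversky loss on flattened masks; alpha weights false positives, beta false negatives."""
    tp = sum(p * t for p, t in zip(pred, target))
    fp = sum(p * (1.0 - t) for p, t in zip(pred, target))
    fn = sum((1.0 - p) * t for p, t in zip(pred, target))
    return 1.0 - (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)


# An under-segmented prediction: one true positive, two missed foreground pixels.
pred = [1.0, 0.0, 0.0, 0.0]
target = [1.0, 1.0, 1.0, 0.0]

loss_fn_heavy = tversky_loss(pred, target)                       # beta > alpha
loss_dice_like = tversky_loss(pred, target, alpha=0.5, beta=0.5)

print("Tversky loss (alpha=0.3, beta=0.7):", loss_fn_heavy)
print("Tversky loss (alpha=0.5, beta=0.5):", loss_dice_like)
```

With `alpha = beta = 0.5` the index reduces to the Dice coefficient, so this loss can be seen as a Dice loss with an adjustable trade-off between missed foreground and false alarms.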
The training set will not be released; the code will be open-sourced later.