[MICCAI2019] A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation

Author:
Robin Brügger, CV Lab, ETH Zürich


Medical image segmentation commonly uses 3D networks, and GPU memory consumption often constrains the network architecture and depth, which in turn limits the final accuracy. The paper borrows the idea of the reversible block to address this problem.

Reversible block

The design of this block is quite clever. The input x is first split along the channel dimension into two halves, x1 and x2. Equation (1) below produces y1 and y2; thanks to the special structure, x1 and x2 can in turn be recovered from y1 and y2 via Equation (2).
(1)  y1 = x1 + F(x2),   y2 = x2 + G(y1)
(2)  x2 = y2 - G(y1),   x1 = y1 - F(x2)
(Figure: the reversible block, forward and inverse computation)
A large share of GPU memory during training goes to storing the intermediate activations of the forward pass (they are needed for backpropagation). With reversible blocks these intermediate results no longer have to be stored: only the final outputs are kept, and every intermediate result can be recomputed from them.
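As a sanity check, here is a minimal sketch (toy 3D convolutions standing in for F and G, with hypothetical shapes, not the paper's actual subnetworks) showing that the inverse equations recover the inputs up to floating-point error, which is what makes it safe to discard the intermediate activations:

import torch
import torch.nn as nn

f_net = nn.Conv3d(8, 8, kernel_size=3, padding=1)  # stands in for the F subnetwork
g_net = nn.Conv3d(8, 8, kernel_size=3, padding=1)  # stands in for the G subnetwork

x = torch.randn(1, 16, 8, 8, 8)        # (batch, channels, D, H, W)
x1, x2 = torch.chunk(x, 2, dim=1)      # channel-wise split

with torch.no_grad():
    y1 = x1 + f_net(x2)                # Eq. (1)
    y2 = x2 + g_net(y1)

    x2_rec = y2 - g_net(y1)            # Eq. (2): invert the block
    x1_rec = y1 - f_net(x2_rec)

print(torch.allclose(x1, x1_rec, atol=1e-5))  # True
print(torch.allclose(x2, x2_rec, atol=1e-5))  # True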

Method

The paper builds on the architecture of No-New-Net, the second-place method of the MICCAI BraTS 2018 challenge. After introducing reversible blocks, the network structure is as follows:
(Figure: the partially reversible U-Net architecture)
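As a rough illustration only (the channel count, normalization, and activation below are my assumptions, loosely following the instance-norm + leaky-ReLU convention of No-New-Net, not details taken from the paper), one encoder level could be made reversible by wrapping two small convolutional subnetworks, each acting on half of the level's channels, in the ReversibleBlock defined in the code section below:

import torch.nn as nn

def half_channel_subnet(ch):
    # F and G each see half of the level's channels and must preserve their shape
    return nn.Sequential(
        nn.InstanceNorm3d(ch),
        nn.LeakyReLU(inplace=True),
        nn.Conv3d(ch, ch, kernel_size=3, padding=1),
    )

level_channels = 32  # hypothetical channel count for one encoder level
rev_level = ReversibleBlock(half_channel_subnet(level_channels // 2),
                            half_channel_subnet(level_channels // 2))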

Results

The results are strong. Comparing the first two rows shows that the reversible blocks save about 2.5 GB of GPU memory, which makes full-volume training possible on a 12 GB card, and the resulting model also beats the single-model No-New-Net.
(Table: segmentation results and GPU memory usage, compared with No-New-Net)

Code

The code of the reversible block module is shown below; it took me a while to roughly understand the backward-pass part. f.backward(dy) is the chain rule: the gradients computed by f.backward() are multiplied by the gradient dy that flowed back from the later layers.
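A tiny standalone sketch of my own (toy tensors, not from the paper) of what tensor.backward(gradient) does: it computes a vector-Jacobian product, multiplying the upstream gradient into the local derivatives, which is exactly the chain-rule step used in backward_pass below.

import torch

w = torch.tensor([2.0, 3.0], requires_grad=True)
y = w * w                                # local derivative dy/dw = 2*w = [4, 6]
upstream = torch.tensor([10.0, 100.0])   # pretend this is dL/dy from the later layers
y.backward(upstream)                     # chain rule: dL/dw = upstream * dy/dw
print(w.grad)                            # tensor([ 40., 600.])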

import torch
import torch.nn as nn
#import torch.autograd.function as func

class ReversibleBlock(nn.Module):
    '''
    Elementary building block for building (partially) reversible architectures
    Implementation of the Reversible block described in the RevNet paper
    (https://arxiv.org/abs/1707.04585). Must be used inside a :class:`revtorch.ReversibleSequence`
    for autograd support.
    Arguments:
        f_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
        g_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
    '''

    def __init__(self, f_block, g_block):
        super(ReversibleBlock, self).__init__()
        self.f_block = f_block
        self.g_block = g_block

    def forward(self, x):
        """
        Performs the forward pass of the reversible block. Does not record any gradients.
        :param x: Input tensor. Must be splittable along dimension 1.
        :return: Output tensor of the same shape as the input tensor
        """
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1, y2 = None, None
        with torch.no_grad():
            y1 = x1 + self.f_block(x2)
            y2 = x2 + self.g_block(y1)

        return torch.cat([y1, y2], dim=1)

    def backward_pass(self, y, dy):
        """
        Performs the backward pass of the reversible block.
        Computes the gradients of the block's parameters in f_block and g_block, reconstructs the inputs of the
        forward pass, and computes their gradients.
        :param y: Outputs of the reversible block
        :param dy: Derivatives of the outputs
        :return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outputs.
        """
        
        # Split the arguments channel-wise
        y1, y2 = torch.chunk(y, 2, dim=1)
        del y
        assert (not y1.requires_grad), "y1 must already be detached"
        assert (not y2.requires_grad), "y2 must already be detached"
        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        del dy
        assert (not dy1.requires_grad), "dy1 must not require grad"
        assert (not dy2.requires_grad), "dy2 must not require grad"

        # Enable autograd for y1 and y2. This ensures that PyTorch
        # keeps track of ops. that use y1 and y2 as inputs in a DAG
        y1.requires_grad = True
        y2.requires_grad = True

        # Ensures that PyTorch tracks the operations in a DAG
        with torch.enable_grad():
            gy1 = self.g_block(y1)

            # Use autograd framework to differentiate the calculation. The
            # derivatives of the parameters of G are set as a side effect
            gy1.backward(dy2)

        with torch.no_grad():
            x2 = y2 - gy1 # Restore the second half of the forward() input
            del y2, gy1

            # The gradient of x1 is the sum of the gradient of the output
            # y1 as well as the gradient that flows back through G
            # (The gradient that flows back through G is stored in y1.grad)
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f_block(x2)

            # Use autograd framework to differentiate the calculation. The
            # derivatives of the parameters of F are set as a side effect
            fx2.backward(dx1)

        with torch.no_grad():
            x1 = y1 - fx2 # Restore the first half of the forward() input
            del y1, fx2

            # The gradient of x2 is the sum of the gradient of the output
            # y2 as well as the gradient that flows back through F
            # (The gradient that flows back through F is stored in x2.grad)
            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            # Undo the channelwise split
            x = torch.cat([x1, x2.detach()], dim=1)
            dx = torch.cat([dx1, dx2], dim=1)

        return x, dx
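
The docstring above says that ReversibleBlock must be used inside a revtorch.ReversibleSequence for autograd support. As a rough sketch of how such a wrapper can look (my own simplified version based on the block interface above, not the exact revtorch source), a custom autograd.Function runs the blocks without building a graph in the forward pass, and during the backward pass walks them in reverse order, calling backward_pass to reconstruct inputs and gradients:

class _ReversibleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, blocks):
        # Run every block without tracking gradients; only the final output is kept
        for block in blocks:
            x = block(x)
        ctx.y = x.detach()   # plain reference so it can be freed early during backward
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        del ctx.y
        # Walk the blocks in reverse, reconstructing their inputs and input gradients
        for i in range(len(ctx.blocks) - 1, -1, -1):
            y, dy = ctx.blocks[i].backward_pass(y, dy)
        del ctx.blocks
        return dy, None  # no gradient for the ModuleList argument

class ReversibleSequence(nn.Module):
    def __init__(self, blocks):
        super(ReversibleSequence, self).__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        return _ReversibleFunction.apply(x, self.blocks)

With this kind of wrapper, a level of the network can be written as ReversibleSequence([ReversibleBlock(f, g), ...]), and only that sequence's final activation needs to be stored during training.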

My Notes

I think the idea behind this paper is excellent. First, it targets a real pain point in medical image processing, namely GPU memory consumption: most researchers are memory-constrained, and 12 GB cards are the most common hardware. Second, it brings in the reversible block idea from another field, turns it into a concrete solution for this problem, and the final experimental results are strong. The paper is a good source of inspiration for my own research.
Of course, the memory savings come at the cost of longer training time.
