作者信息:
Robin Brügger, CV Lab,ETH Zürich
醫療影像常用3D網絡,顯存佔用經常制約了網絡結構與深度,從而對最終精度產生影響。文章主要借鑑了reversible block 的思路來解決上述問題。
reversible block
該block設計很巧妙。輸入x 按通道數先分成兩組,x1, x2。利用如下公式(1),得到y1,y2,由於特殊的結構設計,x1,x2反過來又可以由公式(2) 通過y1,y2計算得到。
網絡訓練時顯存佔用很大一部分是儲存前向傳播的中間結果(因爲反向傳播時需要用到),使用 reversible block 後,中間結果無需保存,只要保存最後輸出的結果,中間結果都可以反推得到。
Method
文章基於MICCAI Brats18挑戰賽第二名 No-New-Net 的結構進行改進,引入reversible block後的網絡結構如下:
Results
結果很好,第一二行比較可以看到使用reversible block後,顯存節約2.5G,使得在12G顯存下使用full volume 訓練成爲可能,與No-New-Net的單模型比也要強。
代碼
reversible block模塊部分的代碼如下,反向傳播的代碼花了一定時間才大致瞭解。f.backward(dy)
是鏈式法則的意思:把f.backward()
得到的梯度乘上之前層反傳得到的梯度dy
,可以參考這個資料
import torch
import torch.nn as nn
#import torch.autograd.function as func
class ReversibleBlock(nn.Module):
'''
Elementary building block for building (partially) reversible architectures
Implementation of the Reversible block described in the RevNet paper
(https://arxiv.org/abs/1707.04585). Must be used inside a :class:`revtorch.ReversibleSequence`
for autograd support.
Arguments:
f_block (nn.Module): arbitrary subnetwork whos output shape is equal to its input shape
g_block (nn.Module): arbitrary subnetwork whos output shape is equal to its input shape
'''
def __init__(self, f_block, g_block):
super(ReversibleBlock, self).__init__()
self.f_block = f_block
self.g_block = g_block
def forward(self, x):
"""
Performs the forward pass of the reversible block. Does not record any gradients.
:param x: Input tensor. Must be splittable along dimension 1.
:return: Output tensor of the same shape as the input tensor
"""
x1, x2 = torch.chunk(x, 2, dim=1)
y1, y2 = None, None
with torch.no_grad():
y1 = x1 + self.f_block(x2)
y2 = x2 + self.g_block(y1)
return torch.cat([y1, y2], dim=1)
def backward_pass(self, y, dy):
"""
Performs the backward pass of the reversible block.
Calculates the derivatives of the block's parameters in f_block and g_block, as well as the inputs of the
forward pass and its gradients.
:param y: Outputs of the reversible block
:param dy: Derivatives of the outputs
:return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outptus.
"""
# Split the arguments channel-wise
y1, y2 = torch.chunk(y, 2, dim=1)
del y
assert (not y1.requires_grad), "y1 must already be detached"
assert (not y2.requires_grad), "y2 must already be detached"
dy1, dy2 = torch.chunk(dy, 2, dim=1)
del dy
assert (not dy1.requires_grad), "dy1 must not require grad"
assert (not dy2.requires_grad), "dy2 must not require grad"
# Enable autograd for y1 and y2. This ensures that PyTorch
# keeps track of ops. that use y1 and y2 as inputs in a DAG
y1.requires_grad = True
y2.requires_grad = True
# Ensures that PyTorch tracks the operations in a DAG
with torch.enable_grad():
gy1 = self.g_block(y1)
# Use autograd framework to differentiate the calculation. The
# derivatives of the parameters of G are set as a side effect
gy1.backward(dy2)
with torch.no_grad():
x2 = y2 - gy1 # Restore first input of forward()
del y2, gy1
# The gradient of x1 is the sum of the gradient of the output
# y1 as well as the gradient that flows back through G
# (The gradient that flows back through G is stored in y1.grad)
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f_block(x2)
# Use autograd framework to differentiate the calculation. The
# derivatives of the parameters of F are set as a side effec
fx2.backward(dx1)
with torch.no_grad():
x1 = y1 - fx2 # Restore second input of forward()
del y1, fx2
# The gradient of x2 is the sum of the gradient of the output
# y2 as well as the gradient that flows back through F
# (The gradient that flows back through F is stored in x2.grad)
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
# Undo the channelwise split
x = torch.cat([x1, x2.detach()], dim=1)
dx = torch.cat([dx1, dx2], dim=1)
return x, dx
我的筆記
我覺得這篇文章思路很棒,一是本文針對到了醫療影像處理的一個痛點,即顯存佔用。大部分研究者顯存受限,12G爲最常用的設備。二是他引入了其他領域的reversible block的思路,該問題提出了一個解決思路,並且最終的實驗結果也很好。本文對我的研究思路有很好的啓發。
當然,節約的顯存是以更長的訓練時間爲代價的。