ESRGAN:基於GAN的增強超分辨率方法(附代碼解析)

原文鏈接:https://mp.weixin.qq.com/s/zw72I9FTiKq3ZA4gcd5K3A

之前看的文章裏有提到 GAN 在圖像修復時更容易得到符合視覺上效果更好的圖像,所以也是看了一些結合 GAN 的圖像修復工作。

 

ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks 發表於 ECCV 2018 的 Workshops,作者在 SRGAN 的基礎上進行了改進,包括改進網絡的結構、判決器的判決形式,以及更換了一個用於計算感知域損失的預訓練網絡。

 

 

 

超分辨率生成對抗網絡(SRGAN)是一項開創性的工作,能夠在單一圖像超分辨率中生成逼真的紋理。這項工作發表於 CVPR 2017。

 

但是,放大後的細節通常伴隨着令人不快的僞影。爲了更進一步地提升視覺質量,作者仔細研究了 SRGAN 的三個關鍵部分:1)網絡結構;2)對抗性損失;3)感知域損失。並對每一項進行改進,得到 ESRGAN。

 

具體而言,文章提出了一種 Residual-in-Residual Dense Block (RRDB) 的網絡單元,在這個單元中,去掉了 BN(Batch Norm)層。此外,作者借鑑了 Relativistic GAN 的想法,讓判別器預測圖像的真實性而不是圖像“是否是 fake 圖像”。

 

最後,文章對感知域損失進行改進,使用激活前的特徵,這樣可以爲亮度一致性和紋理恢復提供更強的監督。在這些改進的幫助下,ESRGAN 得到了更好的視覺質量以及更逼真和自然的紋理。

 

改進後的效果圖(4 倍放大):

 

 ▲ 在紋理和細節上,ESRGAN都優於SRGAN

SRGAN的思考與貢獻

 

現有的超分辨率網絡在不同的網絡結構設計以及訓練策略下,超分辨的效果得到了很大的提升,特別是 PSNR 指標。但是,基於 PSNR 指標的模型會傾向於生成過度平滑的結果,這些結果缺少必要的高頻信息。PSNR 指標與人類觀察者的主觀評價從根本上就不統一。 

 

一些基於感知域信息驅動的方法已經提出來用於提升超分辨率結果的視覺質量。例如,感知域的損失函數提出來用於在特徵空間(instead of 像素空間)中優化超分辨率模型;生成對抗網絡通過鼓勵網絡生成一些更接近於自然圖像的方法來提升超分辨率的質量;語義圖像先驗信息用於進一步改善恢復的紋理細節。 

 

通過結合上面的方法,SRGAN 模型極大地提升了超分辨率結果的視覺質量。但是 SRGAN 模型得到的圖像和 GT 圖像仍有很大的差距。

 

ESRGAN的改進

 

文章對這三點做出改進:

 

1. 網絡的基本單元從基本的殘差單元變爲 Residual-in-Residual Dense Block (RRDB);

 

2. GAN 網絡改進爲 Relativistic average GAN (RaGAN);

 

3. 改進感知域損失函數,使用激活前的 VGG 特徵,這個改進會提供更尖銳的邊緣和更符合視覺的結果。

 

網絡結構及思想

 

生成器部分

 

首先,作者參考 SRResNet 結構作爲整體的網絡結構,SRResNet 的基本結構如下:

 

 SRResNet基本結構

 

爲了提升 SRGAN 重構的圖像質量,作者主要對生成器 G 做出如下改變:

 

1. 去掉所有的 BN 層;

 

2. 把原始的 block 變爲 Residual-in-Residual Dense Block (RRDB),這個 block 結合了多層的殘差網絡和密集連接。

 

如下圖所示:

 

 

思想

 

BN 層的影響:對於不同的基於 PSNR 的任務(包括超分辨率和去模糊)來說,去掉 BN 層已經被證明會提高表現和減小計算複雜度。

 

BN 層在訓練時,使用一個 batch 的數據的均值和方差對該 batch 特徵進行歸一化,在測試時,使用在整個測試集上的數據預測的均值和方差。當訓練集和測試集的統計量有很大不同的時候,BN 層就會傾向於生成不好的僞影,並且限制模型的泛化能力。

 

作者發現,BN 層在網絡比較深,而且在 GAN 框架下進行訓練的時候,更會產生僞影。這些僞影偶爾出現在迭代和不同的設置中,違反了對訓練穩定性能的需求。所以爲了穩定的訓練和一致的性能,作者去掉了 BN 層。此外,去掉 BN 層也能提高模型的泛化能力,減少計算複雜度和內存佔用。

 

Trick

 

除了上述的改進,作者也使用了一些技巧來訓練深層網絡:

 

1. 對殘差信息進行 scaling,即將殘差信息乘以一個 0 到 1 之間的數,用於防止不穩定;

 

2. 更小的初始化,作者發現當初始化參數的方差變小時,殘差結構更容易進行訓練。

 

判別器部分 

 

除了改進的生成器,作者也基於 Relativistic GAN 改進了判別器。判別器 D 使用的網絡是 VGG 網絡,SRGAN 中的判別器 D 用於估計輸入到判別器中的圖像是真實且自然圖像的概率,而 Relativistic 判別器則嘗試估計真實圖像相對來說比 fake 圖像更逼真的概率。 

 

如下圖所示:

 

 

具體而言,作者把標準的判別器換成 Relativistic average Discriminator(RaD),所以判別器的損失函數定義爲:

 

 

對應的生成器的對抗損失函數爲:

 

 

求均值的操作是通過對 mini-batch 中的所有數據求平均得到的,xf 是原始低分辨圖像經過生成器以後的圖像。

 

可以觀察到,對抗損失包含了 xr 和 xf,所以這個生成器受益於對抗訓練中的生成數據和實際數據的梯度,這種調整會使得網絡學習到更尖銳的邊緣和更細節的紋理。

 

感知域損失

 

文章也提出了一個更有效的感知域損失,使用激活前的特徵(VGG16 網絡)。

 

感知域的損失當前是定義在一個預訓練的深度網絡的激活層,這一層中兩個激活了的特徵的距離會被最小化。

 

與此相反,文章使用的特徵是激活前的特徵,這樣會克服兩個缺點。第一,激活後的特徵是非常稀疏的,特別是在很深的網絡中。這種稀疏的激活提供的監督效果是很弱的,會造成性能低下;第二,使用激活後的特徵會導致重建圖像與 GT 的亮度不一致。

 

如圖所示:

 

 使用激活前與激活後的特徵的比較:a. 亮度 b. 細節

 

作者對使用的感知域損失進行了探索。與目前多數使用的用於圖像分類的 VGG 網絡構建的感知域損失相反,作者提出一種更適合於超分辨的感知域損失,這個損失基於一個用於材料識別的 VGG16 網絡(MINCNet),這個網絡更聚焦於紋理而不是物體。儘管這樣帶來的增益很小,但作者仍然相信,探索關注紋理的感知域損失對超分辨至關重要。

 

損失函數

 

經過上面對網絡模塊的定義和構建以後,再定義損失函數,就可以進行訓練了。

 

對於生成器 G,它的損失函數爲:

 

 

即爲感知域損失,作者的源碼取的是 L1 Loss,即爲上面定義的生成器損失,而 L1 則爲 pixel-wise 損失,即, 實驗中取

 

對於判別器,其損失函數就是上面提到的:

 

 

網絡細節

 

生成器網絡G

 

要定義 RDDB 模塊,首先要定義 Dense Block,而 Dense Block 裏面又有卷積層,LReLU 層以及密集連接,所以首先將卷積層和 LReLU 層進行模塊化,這部分的代碼如下(PyTorch):

 

def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
    '''
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
    '''
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
            dilation=dilation, bias=bias, groups=groups)
    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)

 

注意這裏的 pad_type='zero' 並不是指 padding=0,源碼中定義了兩個函數,針對不同模式下的 padding:

 

def pad(pad_type, padding):
    # helper selecting padding layer
    # if padding is 'zero', do by conv layers
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding

 

所以當 pad_type='zero' 時,執行的是 get_valid_padding 函數,根據輸入參數可知此時 padding=1。 

 

模塊化以後,對 Dense Block 進行定義:

 

class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5.mul(0.2) + x

 

前面提到的對殘差信息進行 scaling,在這裏可以看出來,係數爲 0.2。可以看到在 kernel size(3×3)和 stride=1,padding=1 的設置下,特徵圖的大小始終不變,但是通道數由於 concat 的原因,每次都會增加 gc 個通道,但是會在最後一層由變回原來的通道數 nc,這裏的參數 norm_type=None,表示不要 Batch Norm。 

 

定義了 Dense Block 以後,就可以組成 RDDB 了:

 

class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out.mul(0.2) + x

 

因爲特徵圖大小始終不變,所以需要定義上採樣模塊進行放大,得到最後的結果:

 

def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
    # Up conv
    # described in https://distill.pub/2016/deconv-checkerboard/
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \
                        pad_type=pad_type, norm_type=norm_type, act_type=act_type)
    return sequential(upsample, conv)

 

參考 SRResNet,還需要一個 Shortcut 連接模塊:

 

class ShortcutBlock(nn.Module):
    #Elementwise sum the output of a submodule to its input
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        tmpstr = 'Identity + \n|'
        modstr = self.sub.__repr__().replace('\n', '\n|')
        tmpstr = tmpstr + modstr
        return tmpstr

 

定義好上面的模塊以後,就可以定義生成器網絡 G(RDDBNet):

 

class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x

 

注意到這裏有個參數 nb,這個參數控制網絡中 RDDB 的數量,作者取的是 23。

 

判別器網絡D

 

前面提到,判別器 D 的網絡結構爲 VGG 網絡,定義如下(輸入圖像 size 爲 128×128):

 

# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()
        # features
        # hxw, c
        # 128, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 64, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 32, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 16, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 8, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
            conv9)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

 

可以看到,這裏使用了 batch norm,層間的激活函數爲 leakyReLU,base_nf參數爲基礎通道數,爲 64。經過特徵提取以後,原本爲 128×128×1(/3) 的輸入圖像輸出爲 4×4×512。再經過其定義的 classifier 得到輸出值。 

 

提取感知域損失的網絡(Perceptual Network) 

 

文章使用了一個用於材料識別的 VGG16 網絡(MINCNet)來提取感知域特徵,定義如下:

 

class MINCNet(nn.Module):
    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)

    def forward(self, x):
        out = self.ReLU(self.conv11(x))
        out = self.ReLU(self.conv12(out))
        out = self.maxpool1(out)
        out = self.ReLU(self.conv21(out))
        out = self.ReLU(self.conv22(out))
        out = self.maxpool2(out)
        out = self.ReLU(self.conv31(out))
        out = self.ReLU(self.conv32(out))
        out = self.ReLU(self.conv33(out))
        out = self.maxpool3(out)
        out = self.ReLU(self.conv41(out))
        out = self.ReLU(self.conv42(out))
        out = self.ReLU(self.conv43(out))
        out = self.maxpool4(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        out = self.conv53(out)
        return out

 

再引入預訓練參數,就可以進行特徵提取:

 

class MINCFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \
                device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        self.features.load_state_dict(
            torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
        self.features.eval()
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        output = self.features(x)
        return output

 

網絡插值思想

 

爲了平衡感知質量和 PSNR 等評價值,作者提出了一個靈活且有效的方法——網絡插值。具體而言,作者首先基於 PSNR 方法訓練的得到的網絡 G_PSNR,然後再用基於 GAN 的網絡 G_GAN 進行 finetune。

 

然後,對這兩個網絡相應的網絡參數進行插值得到一個插值後的網絡 G_INTERP:

 

 

這樣就可以通過 α 值來調整效果。

 

訓練細節

 

放大倍數:4;mini-batch:16。

 

通過 Matlab 的 bicubic 函數對 HR 圖像進行降採樣得到 LR 圖像。 

 

HR patch 大小:128×128。實驗發現使用大的 patch 時,訓練一個深層網絡效果會更好,因爲一個增大的感受域會幫助模型捕捉更具有語義的信息。

 

訓練過程如下: 

 

1. 訓練一個基於 PSNR 指標的模型(L1 Loss),初始化學習率:2×1e-4,每 200000 個 mini-batch 學習率除以 2;

 

2. 以 1 中訓練的模型作爲生成器的初始化。

 

 

初始學習率:1e-4,並在 50k,100k,200k,300k 迭代後減半。

 

一個基於像素損失函數進行優化的預訓練模型會幫助基於 GAN 的模型生成更符合視覺的結果,原因如下:

 

1. 可以避免生成器不希望的局部最優;

 

2. 再預訓練以後,判別器所得到的輸入圖像的質量是相對較好的,而不是完全初始化的圖像,這樣會使判別器更關注到紋理的判別。

 

優化器:Adam (β1=0.9, β2=0.999);交替更新生成器和判別器,直到收斂。 

 

生成器的設置:1.16 層(基本的殘差結構);2.23層(RDDB)。

 

數據集:DIV2K,Flickr2K,OST(有豐富紋理信息的數據集會是模型產生更自然的結果)。

 

對比實驗(4倍放大)

 

針對文中提到的各種改進,包括移除 BN,使用激活前特徵作爲感知域特徵,修改 GAN 的判別條件,以及提出的 RDDB,作者做了詳細的對比試驗,結果如下:

 

 

經過實驗以後,作者得出結論: 

 

1. 去掉 BN:並沒有降低網絡的性能,而且節省了計算資源和內存佔用。而且發現當網絡變深變複雜時,帶 BN 層的模型更傾向於產生影響視覺效果的僞影;

 

2. 使用激活前的特徵:得到的圖像的亮度更準確,而且可以產生更尖銳的邊緣和更豐富的細節;

 

3. RaGAN:產生更尖銳的邊緣和更豐富的細節;

 

4. RDDB:更加提升恢復得到的紋理(因爲深度模型具有強大的表示能力來捕獲語義信息),而且可以去除噪聲。

 

網絡插值實驗 

 

爲了平衡視覺效果和 PSNR 等性能指標,作者對網絡插值參數 α 的取值進行了實驗,結果如下:

 

 

此外,作者還對比了網絡插值和圖像插值的效果。圖像插值即指分別由兩個網絡輸出的圖像進行插值。通過對比實驗可以看到,圖像插值對消除僞影的效果不如網絡插值。

 

與SOTA方法對比(4倍放大)

 

 

可以看到,ESRGAN 得到的圖像 PSNR 值不高,但是從視覺效果上看會更好,Percpetual Index 值更小(越小越好),而且 ESRGAN 在 PIRM-SR 競賽上也獲得了第一名(在 Percpetual Index 指標上)。

 

總結

 

文章提出的 ESRGAN 在 SRGAN 的基礎上做出了改進,包括去除 BN 層,基本結構換成 RDDB,改進 GAN 中判別器的判別目標,以及使用激活前的特徵構成感知域損失函數,實驗證明這些改進對提升輸出圖像的視覺效果都有作用。

 

此外,作者也使用了一些技巧來提升網絡的性能,包括對殘差信息的 scaling,以及更小的初始化。最後,作者使用了一種網絡插值的方法來平衡輸出圖像的視覺效果和 PSNR 等指標值。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章