Further Modifications to the PSGAN Network

After testing PSGAN, the results look quite good on many test cases, but some pain points remain. For example, when the reference image is a side-facing face, lighting and shadow cause a noticeable difference in makeup intensity between the left and right eyes, and this difference is then reflected on the left and right eyes of the result even when the source image is a frontal face.

As shown below:

[example images]
PSGAN computes the attention map using pixel-wise relative distances (the Diff_A and diff_B tensors), and one unwanted effect this causes is: in the right image the lip region does look very similar to the left image, but no lipstick is applied at the left and right edges of the lips, probably because the attention weight given to relative_position is too large. A sketch of what these relative-position features are follows the example images below.

As shown below:

[example images]
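
For context, here is a minimal sketch of what such relative-position features look like, based on my reading of the PSGAN paper rather than the original implementation (the function and argument names are placeholders, not the code that produced Diff_A / diff_B):

import torch

def relative_position_feature(height, width, landmarks):
    """Per-pixel offsets to every facial landmark on the feature map.

    landmarks: a (K, 2) tensor of (x, y) landmark coordinates in feature-map scale.
    Returns a (2*K, height, width) tensor; PSGAN concatenates a (weighted) feature
    of this kind with the visual features before computing the attention energies,
    so an overly large weight lets geometry dominate appearance.
    """
    ys = torch.arange(height, dtype=torch.float32).view(height, 1).expand(height, width)
    xs = torch.arange(width, dtype=torch.float32).view(1, width).expand(height, width)
    diffs = []
    for x_l, y_l in landmarks.tolist():
        diffs.append(xs - x_l)  # horizontal offset of every pixel to this landmark
        diffs.append(ys - y_l)  # vertical offset of every pixel to this landmark
    return torch.stack(diffs, dim=0)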

Based on the problems above, I made some modifications to the network structure and the training code. When computing the makeup loss during training, the left and right eyes are no longer handled separately; instead, a single histogram makeup loss is computed over both eyes at once. For the network structure, the attention map is now computed directly from the two feature maps. I will train the model first and see whether the results improve; a rough sketch of the loss-side change follows.
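
A minimal sketch of the combined-eyes histogram loss, assuming the trainer exposes a histogram-matching helper and per-region eye masks (all names below, such as histogram_matching and criterion_l1, are placeholders rather than the actual training code):

import torch

def both_eyes_makeup_loss(fake_image, reference_image,
                          mask_left_eye_src, mask_right_eye_src,
                          mask_left_eye_ref, mask_right_eye_ref,
                          histogram_matching, criterion_l1):
    # merge the left- and right-eye masks so one histogram is matched over both eyes
    mask_eyes_src = torch.clamp(mask_left_eye_src + mask_right_eye_src, 0, 1)
    mask_eyes_ref = torch.clamp(mask_left_eye_ref + mask_right_eye_ref, 0, 1)
    # match the generated eye region to the reference eye region in one pass,
    # then penalize the L1 distance to the matched target
    target_eyes = histogram_matching(fake_image * mask_eyes_src,
                                     reference_image * mask_eyes_ref)
    return criterion_l1(fake_image * mask_eyes_src, target_eyes)

Matching both eyes against a single histogram should keep the transferred eye makeup at a similar intensity on the two sides, even when the reference face is lit unevenly.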

The network structure code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

from ops.spectral_norm import spectral_norm as SpectralNorm


class ResidualBlock(nn.Module):
    """Residual Block."""

    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True))

    def forward(self, x):
        return x + self.main(x)


class Discriminator(nn.Module):
    """Discriminator. PatchGAN."""

    def __init__(self, image_size=128, conv_dim=64, repeat_num=3, norm='SN'):
        super(Discriminator, self).__init__()

        layers = []
        if norm == 'SN':
            layers.append(SpectralNorm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        if norm == 'SN':
            self.conv1 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)

        # conv1 keeps a spatial output: a 256*256 input gives a 30*30 patch map
        # self.conv2 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=k_size, bias=False))
        # conv2 output a single number

    def forward(self, x):
        h = self.main(x)
        out_makeup = self.conv1(h)
        return out_makeup.squeeze()


class VGG(nn.Module):
    def __init__(self, pool='max'):
        super(VGG, self).__init__()
        # vgg modules
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        if pool == 'max':
            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif pool == 'avg':
            self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x, out_keys):
        out = {}
        out['r11'] = F.relu(self.conv1_1(x))
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['r34'] = F.relu(self.conv3_4(out['r33']))
        out['p3'] = self.pool3(out['r34'])
        out['r41'] = F.relu(self.conv4_1(out['p3']))

        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['r44'] = F.relu(self.conv4_4(out['r43']))
        out['p4'] = self.pool4(out['r44'])
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        out['r54'] = F.relu(self.conv5_4(out['r53']))
        out['p5'] = self.pool5(out['r54'])

        return [out[key] for key in out_keys]


# Makeup Apply Network (MANet)
class Generator(nn.Module):
    """Generator. Encoder-Decoder Architecture."""

    def __init__(self, conv_dim=64, repeat_num=6):
        super(Generator, self).__init__()

        encoder_layers = [nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
                          nn.InstanceNorm2d(conv_dim, affine=False), nn.ReLU(inplace=True)]
        # the MANet setup uses InstanceNorm without affine parameters

        # Down-Sampling
        curr_dim = conv_dim
        for i in range(2):
            encoder_layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            encoder_layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=False))
            encoder_layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck
        for i in range(3):
            encoder_layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        decoder_layers = []
        for i in range(3):
            decoder_layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-Sampling
        for i in range(2):
            decoder_layers.append(
                nn.ConvTranspose2d(curr_dim, curr_dim // 2, kernel_size=4, stride=2, padding=1, bias=False))
            decoder_layers.append(nn.InstanceNorm2d(curr_dim // 2, affine=True))
            decoder_layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        decoder_layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        decoder_layers.append(nn.Tanh())

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)
        self.MDNet = MDNet()
        self.AMM = AMM()

    def forward(self, source_image, reference_image, mask_source, mask_ref, gamma=None, beta=None, ret=False,
                mode='train'):
        fm_source = self.encoder(source_image)
        fm_reference = self.MDNet(reference_image)
        if ret:
            gamma, beta = self.AMM(fm_source, fm_reference, mask_source, mask_ref, gamma=gamma, beta=beta, ret=ret,
                                   mode=mode)
            return [gamma, beta]
        morphed_fm = self.AMM(fm_source, fm_reference, mask_source, mask_ref, gamma=gamma, beta=beta, ret=ret,
                              mode=mode)
        result = self.decoder(morphed_fm)
        return result


class MDNet(nn.Module):
    """Generator. Encoder-Decoder Architecture."""

    # MDNet is similar to the encoder of StarGAN
    def __init__(self, conv_dim=64, repeat_num=3):
        super(MDNet, self).__init__()

        layers = [nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
                  nn.InstanceNorm2d(conv_dim, affine=True), nn.ReLU(inplace=True)]

        # Down-Sampling
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck
        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
        self.main = nn.Sequential(*layers)

    def forward(self, reference_image):
        fm_reference = self.main(reference_image)
        return fm_reference


class AMM(nn.Module):
    """Attentive Makeup Morphing module"""

    def __init__(self):
        super(AMM, self).__init__()
        self.gamma_matrix_conv = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)
        self.beta_matrix_conv = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)

    def forward(self, fm_source, fm_reference, mask_source, mask_ref, gamma=None, beta=None, ret=False, mode='train'):
        old_gamma_matrix = self.gamma_matrix_conv(fm_reference)
        old_beta_matrix = self.beta_matrix_conv(fm_reference)

        old_gamma_matrix_source = self.gamma_matrix_conv(fm_source)
        old_beta_matrix_source = self.beta_matrix_conv(fm_source)

        if gamma is None:
            attention_map = self.raw_attention_map(fm_source, fm_reference)
            # gamma, beta = self.raw_atten_feature(mask_source, attention_map, old_gamma_matrix, old_beta_matrix,
            #                                      old_gamma_matrix_source, old_beta_matrix_source)
            gamma, beta = self.pure_atten_feature(attention_map, old_gamma_matrix, old_beta_matrix)
            if ret:
                return [gamma, beta]

        morphed_fm_source = fm_source * (1 + gamma) + beta

        return morphed_fm_source

    @staticmethod
    def raw_attention_map(fm_source, fm_reference):
        batch_size, channels, width, height = fm_reference.size()

        # after reshape, fm_reference has shape (C, H*W)
        temp_fm_reference = fm_reference.view(batch_size, -1, height * width)
        # fm_source is reshaped and then transposed to (H*W, C)
        temp_fm_source = fm_source.view(batch_size, -1, height * width).permute(0, 2, 1)
        # energy has shape (N, N), where N = H*W
        energy = torch.bmm(temp_fm_source, temp_fm_reference)
        energy *= 200  # scaling hyperparameter for the visual features
        attention_map = F.softmax(energy, dim=-1)

        return attention_map

    @staticmethod
    def raw_atten_feature(mask_source, attention_map, old_gamma_matrix, old_beta_matrix, old_gamma_matrix_source,
                          old_beta_matrix_source):
        batch_size, channels, width, height = old_gamma_matrix.size()
        old_gamma_matrix = old_gamma_matrix.view(batch_size, -1, width * height)
        old_beta_matrix = old_beta_matrix.view(batch_size, -1, width * height)
        new_gamma_matrix = torch.bmm(old_gamma_matrix, attention_map.permute(0, 2, 1))
        new_beta_matrix = torch.bmm(old_beta_matrix, attention_map.permute(0, 2, 1))
        new_gamma_matrix = new_gamma_matrix.view(-1, 1, width, height)
        new_beta_matrix = new_beta_matrix.view(-1, 1, width, height)

        reverse_mask_source = 1 - mask_source
        new_mask_source = F.interpolate(mask_source, size=new_gamma_matrix.shape[2:]).repeat(1, channels, 1, 1)
        new_reverse_mask_source = F.interpolate(reverse_mask_source, size=new_gamma_matrix.shape[2:]).repeat(1,
                                                                                                             channels,
                                                                                                             1, 1)
        gamma = new_gamma_matrix * new_mask_source + old_gamma_matrix_source * new_reverse_mask_source
        beta = new_beta_matrix * new_mask_source + old_beta_matrix_source * new_reverse_mask_source
        return gamma, beta

    # modify gamma_matrix using only the attention computed between the two feature maps
    @staticmethod
    def pure_atten_feature(attention_map, old_gamma_matrix, old_beta_matrix):
        batch_size, channels, width, height = old_gamma_matrix.size()
        old_gamma_matrix = old_gamma_matrix.view(batch_size, -1, width * height)
        old_beta_matrix = old_beta_matrix.view(batch_size, -1, width * height)
        new_gamma_matrix = torch.bmm(old_gamma_matrix, attention_map.permute(0, 2, 1))
        new_beta_matrix = torch.bmm(old_beta_matrix, attention_map.permute(0, 2, 1))
        new_gamma_matrix = new_gamma_matrix.view(-1, 1, width, height)
        new_beta_matrix = new_beta_matrix.view(-1, 1, width, height)

        gamma = new_gamma_matrix
        beta = new_beta_matrix
        return gamma, beta

    # Below is the attention computation used in PSGAN, but it needs too much GPU memory and some parts of it feel unreasonable
    @staticmethod
    def get_attention_map(mask_source, mask_ref, fm_source, fm_reference, mode='train'):
        HW = 64 * 64  # hard-coded spatial size of the 64x64 feature maps
        batch_size = 3  # the three facial parts are stacked along the batch dimension

        # split the features into 3 facial parts using the masks
        channels = fm_reference.shape[1]

        mask_source_re = F.interpolate(mask_source, size=64).repeat(1, channels, 1, 1)  # (3, c, h, w)
        fm_source = fm_source.repeat(3, 1, 1, 1)  # (3, c, h, w)
        # when computing attention, we only consider pixels belonging to the same facial region
        fm_source = fm_source * mask_source_re  # (3, c, h, w) 3 stands for 3 parts

        mask_ref_re = F.interpolate(mask_ref, size=64).repeat(1, channels, 1, 1)
        fm_reference = fm_reference.repeat(3, 1, 1, 1)
        fm_reference = fm_reference * mask_ref_re

        theta_input = fm_source
        phi_input = fm_reference

        theta_target = theta_input.view(batch_size, -1, HW)
        theta_target = theta_target.permute(0, 2, 1)

        phi_source = phi_input.view(batch_size, -1, HW)

        weight = torch.bmm(theta_target, phi_source)  # (3, HW, HW)
        if mode == 'train':
            weight = weight.cpu()
            weight_ind = torch.LongTensor(weight.detach().numpy().nonzero())
            weight = weight.cuda()
            weight_ind = weight_ind.cuda()
        else:
            weight_ind = torch.LongTensor(weight.numpy().nonzero())
        weight *= 200  # scaling hyperparameter for the visual features
        weight = F.softmax(weight, dim=-1)
        weight = weight[weight_ind[0], weight_ind[1], weight_ind[2]]
        return torch.sparse.FloatTensor(weight_ind, weight, torch.Size([3, HW, HW]))

    @staticmethod
    def atten_feature(mask_ref, attention_map, old_gamma_matrix, old_beta_matrix):
        # The paper says the idea of gamma and beta comes from style transfer, but not general style transfer, so masks are used here to compute the style of each facial region
        batch_size, channels, width, height = old_gamma_matrix.size()

        mask_ref_re = F.interpolate(mask_ref, size=old_gamma_matrix.shape[2:]).repeat(1, channels, 1, 1)
        gamma_ref_re = old_gamma_matrix.repeat(3, 1, 1, 1)
        old_gamma_matrix = gamma_ref_re * mask_ref_re  # (3, c, h, w)
        beta_ref_re = old_beta_matrix.repeat(3, 1, 1, 1)
        old_beta_matrix = beta_ref_re * mask_ref_re

        old_gamma_matrix = old_gamma_matrix.view(3, 1, -1)
        old_beta_matrix = old_beta_matrix.view(3, 1, -1)

        old_gamma_matrix = old_gamma_matrix.permute(0, 2, 1)
        old_beta_matrix = old_beta_matrix.permute(0, 2, 1)
        new_gamma_matrix = torch.bmm(attention_map.to_dense(), old_gamma_matrix)
        new_beta_matrix = torch.bmm(attention_map.to_dense(), old_beta_matrix)
        gamma = new_gamma_matrix.view(-1, 1, width, height)  # (3, c, h, w)
        beta = new_beta_matrix.view(-1, 1, width, height)

        gamma = (gamma[0] + gamma[1] + gamma[2]).unsqueeze(0)  # (c, h, w) combine the three parts
        beta = (beta[0] + beta[1] + beta[2]).unsqueeze(0)
        return gamma, beta
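

For a quick sanity check of the modified forward pass, a minimal smoke test can be run on dummy tensors. The 256x256 input resolution below is an assumption; the dummy masks are only there to satisfy the forward signature, since the pure_atten_feature path above does not actually use them:

if __name__ == '__main__':
    net = Generator(conv_dim=64, repeat_num=6)
    source_image = torch.randn(1, 3, 256, 256)
    reference_image = torch.randn(1, 3, 256, 256)
    # dummy part masks: unused by the pure attention path, but required by forward()
    mask_source = torch.ones(1, 1, 256, 256)
    mask_ref = torch.ones(1, 1, 256, 256)
    with torch.no_grad():
        result = net(source_image, reference_image, mask_source, mask_ref)
    print(result.shape)  # expected: torch.Size([1, 3, 256, 256])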
