PSGAN Network Modification


The PSGAN I built earlier was only half finished: the attention was computed from the feature maps of the source and reference images alone, without using each pixel's relative distances to the facial landmarks. This week I therefore modified both the network architecture and the data-processing code.
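In PSGAN, the attention weight between a source pixel and a reference pixel is computed from the concatenation of their visual features and their relative positions to the 68 facial landmarks, not from the visual features alone. A minimal sketch of this idea (toy shapes; the 0.01 visual-feature weight mirrors the code below, everything else is illustrative):

import torch
import torch.nn.functional as F

C, HW = 256, 16 * 16  # toy sizes; the real network uses 64 * 64 pixels
v_src, v_ref = torch.randn(1, C, HW), torch.randn(1, C, HW)      # visual features
p_src, p_ref = torch.randn(1, 136, HW), torch.randn(1, 136, HW)  # 68 landmarks * 2 relative coordinates

# concatenate visual and positional features, then take pairwise dot products
q = torch.cat([v_src * 0.01, p_src], dim=1).permute(0, 2, 1)  # (1, HW, C+136)
k = torch.cat([v_ref * 0.01, p_ref], dim=1)                   # (1, C+136, HW)
attention = F.softmax(torch.bmm(q, k), dim=-1)                 # (1, HW, HW), each row sums to 1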


The code for the network architecture is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

from ops.spectral_norm import spectral_norm as SpectralNorm


# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input

class ResidualBlock(nn.Module):
    """Residual Block."""

    def __init__(self, dim_in, dim_out, net_mode=None):
        super(ResidualBlock, self).__init__()
        # MANet uses InstanceNorm without affine parameters; MDNet (and the default) uses affine InstanceNorm
        if net_mode == 'MANet':
            use_affine = False
        else:  # 'MDNet' or unspecified
            use_affine = True
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine))

    def forward(self, x):
        return x + self.main(x)


class Discriminator(nn.Module):
    """Discriminator. PatchGAN."""

    def __init__(self, image_size=128, conv_dim=64, repeat_num=3, norm='SN'):
        super(Discriminator, self).__init__()

        layers = []
        if norm == 'SN':
            layers.append(SpectralNorm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        if norm == 'SN':
            self.conv1 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)

        # conv1 keeps a spatial (patch) output, 256*256 --> 30*30
        # self.conv2 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=k_size, bias=False))
        # conv2 output a single number

    def forward(self, x):
        h = self.main(x)
        out_makeup = self.conv1(h)
        return out_makeup.squeeze()


class VGG(nn.Module):
    def __init__(self, pool='max'):
        super(VGG, self).__init__()
        # vgg modules
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        if pool == 'max':
            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif pool == 'avg':
            self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x, out_keys):
        out = {'r11': F.relu(self.conv1_1(x))}
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['r34'] = F.relu(self.conv3_4(out['r33']))
        out['p3'] = self.pool3(out['r34'])
        out['r41'] = F.relu(self.conv4_1(out['p3']))

        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['r44'] = F.relu(self.conv4_4(out['r43']))
        out['p4'] = self.pool4(out['r44'])
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        out['r54'] = F.relu(self.conv5_4(out['r53']))
        out['p5'] = self.pool5(out['r54'])

        return [out[key] for key in out_keys]


# Makeup Apply Network(MANet)
class Generator(nn.Module):
    """Generator. Encoder-Decoder Architecture."""

    def __init__(self, conv_dim=64):
        super(Generator, self).__init__()

        encoder_layers = [nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
                          nn.InstanceNorm2d(conv_dim, affine=False), nn.ReLU(inplace=True)]

        # Down-Sampling
        curr_dim = conv_dim
        for i in range(2):
            encoder_layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            encoder_layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=False))
            encoder_layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck
        for i in range(3):
            encoder_layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim, net_mode='MANet'))

        decoder_layers = []
        for i in range(3):
            decoder_layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim, net_mode='MANet'))

        # Up-Sampling
        for i in range(2):
            decoder_layers.append(
                nn.ConvTranspose2d(curr_dim, curr_dim // 2, kernel_size=4, stride=2, padding=1, bias=False))
            decoder_layers.append(nn.InstanceNorm2d(curr_dim // 2, affine=True))
            decoder_layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        decoder_layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        decoder_layers.append(nn.Tanh())

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)
        self.MDNet = MDNet()
        self.AMM = AMM()

    def forward(self, source_image, mask_source, rel_pos_source, reference_image, mask_ref, rel_pos_ref):
        fm_source = self.encoder(source_image)
        fm_reference = self.MDNet(reference_image)
        morphed_fm = self.AMM(fm_source, fm_reference, mask_source, mask_ref, rel_pos_source, rel_pos_ref)
        result = self.decoder(morphed_fm)
        return result


class MDNet(nn.Module):
    """Generator. Encoder-Decoder Architecture."""

    # MDNet is similar to the encoder of StarGAN
    def __init__(self, conv_dim=64):
        super(MDNet, self).__init__()

        layers = [nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
                  nn.InstanceNorm2d(conv_dim, affine=True), nn.ReLU(inplace=True)]

        # Down-Sampling
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck
        for i in range(3):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim, net_mode='MDNet'))
        self.main = nn.Sequential(*layers)

    def forward(self, reference_image):
        fm_reference = self.main(reference_image)
        return fm_reference


# The AMM module was modified with reference to the official PSGAN code
class AMM(nn.Module):
    """Attentive Makeup Morphing module"""

    def __init__(self):
        super(AMM, self).__init__()
        self.visual_feature_weight = 0.01  # corresponds to the 0.01 scaling of the visual features in get_attention_map
        self.lambda_matrix_conv = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)
        self.beta_matrix_conv = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def get_attention_map(mask_source, mask_ref, fm_source, fm_reference, rel_pos_source, rel_pos_ref):
        HW = 64 * 64    # spatial size of the 64x64 feature maps
        batch_size = 3  # the 3 facial parts (lip, face, eyes), not the image batch size

        # split the feature maps into the 3 facial parts using the masks
        channels = fm_reference.shape[1]

        mask_source_re = F.interpolate(mask_source, size=64).repeat(1, channels, 1, 1)  # (3, c, h, w)
        fm_source = fm_source.repeat(3, 1, 1, 1)  # (3, c, h, w)
        # when computing attention we only consider pixels belonging to the same facial region
        fm_source = fm_source * mask_source_re  # (3, c, h, w), 3 stands for the 3 parts

        mask_ref_re = F.interpolate(mask_ref, size=64).repeat(1, channels, 1, 1)
        fm_reference = fm_reference.repeat(3, 1, 1, 1)
        fm_reference = fm_reference * mask_ref_re

        # the 0.01 factor (visual_feature_weight) down-weights the visual features relative to the positional features
        theta_input = torch.cat((fm_source * 0.01, rel_pos_source), dim=1)
        phi_input = torch.cat((fm_reference * 0.01, rel_pos_ref), dim=1)

        theta_target = theta_input.view(batch_size, -1, HW)  # (N, C+136, H*W)
        theta_target = theta_target.permute(0, 2, 1)  # (N, H*W, C+136)

        phi_source = phi_input.view(batch_size, -1, HW)  # (N, C+136, H*W)

        weight = torch.bmm(theta_target, phi_source)  # (3, HW, HW)
        # gather the indices of the non-zero similarities on the CPU (via numpy), then move everything back to the GPU
        weight = weight.cpu()
        weight_ind = torch.LongTensor(weight.detach().numpy().nonzero())
        weight = weight.cuda()
        weight_ind = weight_ind.cuda()
        weight *= 200  # hyper-parameter for the visual feature
        weight = F.softmax(weight, dim=-1)
        weight = weight[weight_ind[0], weight_ind[1], weight_ind[2]]
        # keep only the non-zero attention weights and store them in a sparse (3, HW, HW) tensor
        # (why not merge them into a single 1*HW*HW weight at the end?)
        return torch.sparse.FloatTensor(weight_ind, weight, torch.Size([3, HW, HW]))

    @staticmethod
    def atten_feature(mask_ref, attention_map, old_gamma_matrix, old_beta_matrix):
        # The paper notes that gamma and beta are inspired by style transfer, but not general style transfer,
        # so the masks are used here to compute the style of each facial region separately
        batch_size, channels, width, height = old_gamma_matrix.size()
        # channels = gamma_ref.shape[1]

        mask_ref_re = F.interpolate(mask_ref, size=old_gamma_matrix.shape[2:]).repeat(1, channels, 1, 1)
        gamma_ref_re = old_gamma_matrix.repeat(3, 1, 1, 1)
        old_gamma_matrix = gamma_ref_re * mask_ref_re  # (3, c, h, w)
        print('old_gamma_matrix shape1: ', old_gamma_matrix.shape)
        beta_ref_re = old_beta_matrix.repeat(3, 1, 1, 1)
        old_beta_matrix = beta_ref_re * mask_ref_re

        old_gamma_matrix = old_gamma_matrix.view(3, 1, -1)
        print('old_gamma_matrix shape2: ', old_gamma_matrix.shape)
        old_beta_matrix = old_beta_matrix.view(3, 1, -1)

        old_gamma_matrix = old_gamma_matrix.permute(0, 2, 1)
        old_beta_matrix = old_beta_matrix.permute(0, 2, 1)
        print('old_gamma_matrix shape3: ', old_gamma_matrix.shape)
        print('attention_map.to_dense() shape: ', attention_map.to_dense().shape)
        new_gamma_matrix = torch.bmm(attention_map.to_dense(), old_gamma_matrix)
        new_beta_matrix = torch.bmm(attention_map.to_dense(), old_beta_matrix)
        gamma = new_gamma_matrix.view(-1, 1, width, height)  # (3, c, h, w)
        beta = new_beta_matrix.view(-1, 1, width, height)

        gamma = (gamma[0] + gamma[1] + gamma[2]).unsqueeze(0)  # (c, h, w) combine the three parts
        beta = (beta[0] + beta[1] + beta[2]).unsqueeze(0)
        return gamma, beta

    def forward(self, fm_source, fm_reference, mask_source, mask_ref, rel_pos_source, rel_pos_ref):
        # batch_size, channels, width, height = fm_reference.size()

        old_gamma_matrix = self.lambda_matrix_conv(fm_reference)
        old_beta_matrix = self.beta_matrix_conv(fm_reference)

        attention_map = self.get_attention_map(mask_source, mask_ref, fm_source, fm_reference, rel_pos_source,
                                               rel_pos_ref)
        gamma, beta = self.atten_feature(mask_ref, attention_map, old_gamma_matrix, old_beta_matrix)

        # apply the attention-weighted makeup statistics to modify the source feature map
        morphed_fm_source = fm_source * (1 + gamma) + beta

        return morphed_fm_source
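
The essence of AMM is the last two steps above: atten_feature redistributes the reference's per-pixel gamma/beta onto source pixels through the attention map, and forward then applies them to the source feature map. A self-contained toy sketch of just this data flow (made-up sizes, dense random attention, a single facial part):

import torch
import torch.nn.functional as F

C, H, W = 8, 4, 4  # toy sizes; the real code uses C = 256 and H = W = 64
HW = H * W

fm_source = torch.randn(1, C, H, W)
attention = F.softmax(torch.randn(1, HW, HW), dim=-1)  # (1, HW, HW), rows sum to 1

# per-pixel gamma/beta predicted from the reference feature map (one channel each)
gamma_ref = torch.randn(1, 1, H, W)
beta_ref = torch.randn(1, 1, H, W)

# redistribute the reference statistics onto source pixels via the attention map
gamma = torch.bmm(attention, gamma_ref.view(1, HW, 1)).view(1, 1, H, W)
beta = torch.bmm(attention, beta_ref.view(1, HW, 1)).view(1, 1, H, W)

# morph the source feature map, broadcasting gamma/beta across channels
morphed = fm_source * (1 + gamma) + beta
print(morphed.shape)  # torch.Size([1, 8, 4, 4])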

The data-loading code in makeup_utils also needs to be modified accordingly:

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

import faceutils as futils
from ops.histogram_matching import *

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


def ToTensor(pic):
    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float()
    else:
        return img


def copy_area(tar, src, lms):
    rect = [int(min(lms[:, 1])) - preprocess_image.eye_margin,
            int(min(lms[:, 0])) - preprocess_image.eye_margin,
            int(max(lms[:, 1])) + preprocess_image.eye_margin + 1,
            int(max(lms[:, 0])) + preprocess_image.eye_margin + 1]
    tar[:, :, rect[1]:rect[3], rect[0]:rect[2]] = \
        src[:, :, rect[1]:rect[3], rect[0]:rect[2]]


def to_var(x, requires_grad=True):
    # wrap a tensor in a Variable, forwarding the requires_grad flag
    return Variable(x, requires_grad=requires_grad).float()


def preprocess_image(image: Image):
    face = futils.dlib.detect(image)

    assert face, "no faces detected"

    # face[0] is the first detected face; the given image is expected to contain exactly one face
    face = face[0]
    image, face = futils.dlib.crop(image, face)

    # detect landmark
    lms = futils.dlib.landmarks(image, face) * 256 / image.width
    lms = lms.round()
    lms_eye_left = lms[42:48]
    lms_eye_right = lms[36:42]
    lms = lms.transpose((1, 0)).reshape(-1, 1, 1)  # transpose to (y-x)
    lms = np.tile(lms, (1, 256, 256))  # (136, h, w)

    # calculate relative position for each pixel
    fix = np.zeros((256, 256, 68 * 2))
    for i in range(256):  # row (y) h
        for j in range(256):  # column (x) w
            fix[i, j, :68] = i
            fix[i, j, 68:] = j
    fix = fix.transpose((2, 0, 1))  # (136, h, w)
    diff = to_var(torch.Tensor(fix - lms).unsqueeze(0), requires_grad=False)

    # obtain face parsing result
    image = image.resize((512, 512), Image.ANTIALIAS)
    mask = futils.mask.mask(image).resize((256, 256), Image.ANTIALIAS)
    mask = to_var(ToTensor(mask).unsqueeze(0), requires_grad=False)
    mask_lip = (mask == 7).float() + (mask == 9).float()
    mask_face = (mask == 1).float() + (mask == 6).float()

    # mask_eyes needs to be cut out around the eye landmarks
    mask_eyes = torch.zeros_like(mask)
    copy_area(mask_eyes, mask_face, lms_eye_left)
    copy_area(mask_eyes, mask_face, lms_eye_right)
    mask_eyes = to_var(mask_eyes, requires_grad=False)

    mask_list = [mask_lip, mask_face, mask_eyes]
    mask_aug = torch.cat(mask_list, 0)  # (3, 1, h, w)
    # F.interpolate up- or down-samples the input to the given size or scale_factor
    mask_re = F.interpolate(mask_aug, size=preprocess_image.diff_size).repeat(1, diff.shape[1], 1,
                                                                              1)  # (3, 136, 64, 64)
    diff_re = F.interpolate(diff, size=preprocess_image.diff_size).repeat(3, 1, 1, 1)  # (3, 136, 64, 64)
    # this implements the paper's requirement that attention is only computed within the same facial region
    diff_re = diff_re * mask_re  # (3, 136, 64, 64)
    # with dim=1 the norm has shape (3, 1, 64, 64), i.e. the length of each pixel's relative-position vector
    norm = torch.norm(diff_re, dim=1, keepdim=True).repeat(1, diff_re.shape[1], 1, 1)
    # torch.where merges the two tensors element-wise according to the given condition
    norm = torch.where(norm == 0, torch.tensor(1e10), norm)
    diff_re /= norm

    image = image.resize((256, 256), Image.ANTIALIAS)
    real = to_var(transform(image).unsqueeze(0))
    return [real, mask_aug, diff_re]


def preprocess_train_image(image: Image, mask, diff_re):
    real = transform(image).unsqueeze(0)
    return [real, mask, diff_re]


# parameter of eye transfer
preprocess_image.eye_margin = 16
# down sample size
preprocess_image.diff_size = (64, 64)
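
Putting the two files together, a rough inference sketch looks like the following. The image paths and the psgan_net module name are placeholders, faceutils must provide the dlib and mask helpers imported above, and a CUDA device is assumed because get_attention_map moves its tensors to the GPU unconditionally:

import torch
from PIL import Image

from psgan_net import Generator  # assumed module name for the network code above

generator = Generator().cuda().eval()

source_img = Image.open('source.png').convert('RGB')         # placeholder paths
reference_img = Image.open('reference.png').convert('RGB')

real_src, mask_src, rel_src = preprocess_image(source_img)
real_ref, mask_ref, rel_ref = preprocess_image(reference_img)
# shapes: real (1, 3, 256, 256), mask (3, 1, 256, 256), rel (3, 136, 64, 64)

inputs = [t.cuda() for t in (real_src, mask_src, rel_src, real_ref, mask_ref, rel_ref)]
with torch.no_grad():
    fake = generator(*inputs)  # (1, 3, 256, 256), values in [-1, 1] from the final Tanh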
