GAN
每一種著名的代碼都值得研究,覺得過於簡單,也可能是因爲不夠了解又或者產生了錯誤的見解。並不是能用就可以結束對事物的研究發現問題也許會產生更好的結果。
Generator
從隨機數組中產生圖片,利用全連接產生有序數組,可以轉換爲圖片
以下是代碼
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(opt.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *img_shape)
return img
Discriminator
判別器,用於判斷真假圖片,訓練過程中,真假圖片以1:1輸入。
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
其他
損失函數使用了二分類交叉熵(binary cross entropy)
torch.nn.BCELoss()
網絡搭建使用了torch.nn.Sequential()
訓練過程中D&G一起訓練
由於全連接網絡層數較淺,激活函數選擇相對隨意
但是
激活函數的制定需要根據輸出的性質進行選擇,若最後一層的激活函數選擇過於不合理網絡可能不收斂,此時的效果比無激活函數更差
網絡層數較深時,儘量不要選擇sigmoid作爲激活函數,易出現梯度消失
在訓練D時,對生成的圖片進行了反向傳播阻斷detach函數。
圖像數據的產生與還原
在此代碼中,Generator最後一層激活函數爲tanh,因此生成的數組值域爲[-1,1]
對於輸入的圖像則使用了歸一化(可能翻譯爲規範化會好一些?Normalization):
對於歸一化,正則化等操作,有很多博客的解釋可能存在問題,如有代碼建議直接查閱代碼內置解釋(使用python的help函數)
(Image-0.5)/0.5
在訓練過程中二者值域相等,可以正常訓練
還原時採用PyTorch內置函數
from torchvision.utils import save_image
對輸入的數組進行了二次normalize,此處normalize無輸入參數
應對於[-1,1]之間的數字還原至[0,1],然後再還原爲[0,255]整形數字
輸入與輸出在數據轉換過程中保持了一致,同時需要注意,二者的圖片均非二值圖片,有灰色的像素點存在。