論文:DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better
Github:https://github.com/TAMU-VITA/DeblurGANv2
https://github.com/KupynOrest/DeblurGANv2
ICCV 2019
論文提出了DeblurGAN的改進版,DeblurGAN-v2,在efficiency, quality, flexibility 三方面都取得了state-of-the-art 的效果。
主要貢獻:
Framework Level:
對於生成器,爲了更好的保準生成質量,論文首次提出採用Feature Pyramid Network (FPN) 結構進行特徵融合。對於判別器部分,採用帶有最小開方損失(least-square loss )的相對判別器(relativistic discriminator),並且分別結合了全局(global (image) )和局部(local (patch) )2個尺度的判別loss。
Backbone Level:
論文采用了3種骨架網絡,分別爲Inception-ResNet-v2,MobileNet,MobileNet-DSC。Inception-ResNet-v2具有最好的精度,MobileNet和MobileNet-DSC具有更快的速度。
Experiment Level:
在3個指標PSNR, SSIM, perceptual quality 都取得了很好的結果。基於MobileNet-DSC 的DeblurGAN-v2比DeblurGAN快了11倍,並且只有4M大小。
網絡結構:
生成器基本結構爲FPN結構,分別獲取5個分支的特徵輸出,基於上採樣操作進行融合。最後再加入原圖的shortcut分支,得到最終的輸出。
輸入圖片歸一化到了[-1,1],輸出圖片也經過tanh函數歸一化到[-1,1]。
損失函數Loss:
傳統GAN的損失函數:
Least Squares GANs(LSGAN)的損失函數:
該損失有助於使得訓練過程更加平穩,高效。
判別器RaGAN-LS loss :
該loss是在LSGAN loss的基礎上,進行的改進。
生成器整體loss:
其中,Lp表示mean-square-error (MSE)
Lx表示感知loss,表示內容的損失
Ladv表示全局和局部的損失,全局表示整個圖片的損失,局部類比於PatchGAN,表示將整個圖片分塊爲一個一個的70*70的局部圖片的損失。
訓練集:
GoPro :3214 blurry/clear 圖片對,其中2103作訓練,1111做測試。
DVD :6708 blurry/clear 圖片對
NFS :75個視頻
實驗結果:
訓練&測試:
本人使用的是GOPRO數據集進行的訓練。
代碼修改,config/config.yaml
files_a: &FILES_A ./datasets/GOPRO/GOPRO_3840FPS_AVG_3-21/**/*.png
fpn_inception訓練,測試:
從頭開始訓練,python3 train.py
加載預訓練模型訓練,修改,train.py,
def _init_params(self):
self.criterionG, criterionD = get_loss(self.config['model'])
self.netG, netD = get_nets(self.config['model'])
self.netG.load_state_dict(torch.load("offical_models/fpn_inception.h5", map_location='cpu')['model'])
self.netG.cuda()
self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
self.model = get_model(self.config['model'])
self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
self.scheduler_G = self._get_scheduler(self.optimizer_G)
self.scheduler_D = self._get_scheduler(self.optimizer_D)
訓練loss,
Epoch 25, lr 0.0001: 100%|##################################################################################################################| 1000/1000 [07:27<00:00, 2.23it/s, loss=G_loss=-0.0117; PSNR=38.5462; SSIM=0.9783]
Validation: 100%|#############################################################################################################################################################################| 100/100 [00:36<00:00, 2.76it/s]
G_loss=-0.0147; PSNR=36.3670; SSIM=0.9769
開始測試,python3 predict.py 007952_9.png
fpn_inception的測試效果如下,模型大小234M,
fpn_mobilenet訓練,測試:
mobilenet_v2.pth.tar模型的url:http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar
修改,config/config.yaml
g_name: fpn_mobilenet
加載預訓練模型訓練,修改train.py,
def _init_params(self):
self.criterionG, criterionD = get_loss(self.config['model'])
self.netG, netD = get_nets(self.config['model'])
self.netG.load_state_dict(torch.load("offical_models/fpn_mobilenet.h5", map_location='cpu')['model'])
self.netG.cuda()
self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
self.model = get_model(self.config['model'])
self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
self.scheduler_G = self._get_scheduler(self.optimizer_G)
self.scheduler_D = self._get_scheduler(self.optimizer_D)
訓練loss,
Epoch 0, lr 0.0001: 100%|####################################################################################################################| 1000/1000 [05:13<00:00, 3.19it/s, loss=G_loss=0.0194; PSNR=39.9682; SSIM=0.9801]
Validation: 100%|#############################################################################################################################################################################| 100/100 [00:36<00:00, 2.71it/s]
G_loss=0.0275; PSNR=39.7776; SSIM=0.9802
開始測試,python3 predict.py 007952_9.png
fpn_mobilenet的測試效果如下,模型大小13M,