一文掌握圖像超分辨率重建(算法原理、Pytorch實現)——含完整代碼和數據

目錄

一.  圖像超分辨率重建概述

1. 概念

2. 應用領域

3. 研究進展

3.1 傳統超分辨率重建算法

3.2 基於深度學習的超分辨率重建算法

二.  SRResNet算法原理和Pytorch實現

1. 超分重建基本處理流程

2. 構建深度網絡模型提高超分重建性能

3.  基於子像素卷積放大圖像尺寸

4.  SRResNet結構剖析

5. Pytorch實現

5.1 運行環境

5.2 訓練

5.3 評估

5.4 測試

​  

三.  SRGAN算法原理和Pytorch實現

1. 生成對抗網絡(GAN)

2. 感知損失

3. SRGAN結構剖析

4. Pytorch實現

4.1 訓練

4.2 評估

4.3 測試

四. 總結

 參考文獻


一.  圖像超分辨率重建概述

1. 概念

圖像分辨率是一組用於評估圖像中蘊含細節信息豐富程度的性能參數,包括時間分辨率、空間分辨率及色階分辨率等,體現了成像系統實際所能反映物體細節信息的能力。相較於低分辨率圖像,高分辨率圖像通常包含更大的像素密度、更豐富的紋理細節及更高的可信賴度。但在實際上情況中,受採集設備與環境、網絡傳輸介質與帶寬、圖像退化模型本身等諸多因素的約束,我們通常並不能直接得到具有邊緣銳化、無成塊模糊的理想高分辨率圖像。提升圖像分辨率的最直接的做法是對採集系統中的光學硬件進行改進,但是由於製造工藝難以大幅改進並且製造成本十分高昂,因此物理上解決圖像低分辨率問題往往代價太大。由此,從軟件和算法的角度着手,實現圖像超分辨率重建的技術成爲了圖像處理和計算機視覺等多個領域的熱點研究課題。

圖像的超分辨率重建技術指的是將給定的低分辨率圖像通過特定的算法恢復成相應的高分辨率圖像。具體來說,圖像超分辨率重建技術指的是利用數字圖像處理、計算機視覺等領域的相關知識,藉由特定的算法和處理流程,從給定的低分辨率圖像中重建出高分辨率圖像的過程。其旨在克服或補償由於圖像採集系統或採集環境本身的限制,導致的成像圖像模糊、質量低下、感興趣區域不顯著等問題。

簡單來理解超分辨率重建就是將小尺寸圖像變爲大尺寸圖像,使圖像更加“清晰”。具體效果如下圖所示:

                                                                      圖1 圖像超分辨率重建示例

可以看到,通過特定的超分辨率重建算法,使得原本模糊的圖像變得清晰了。讀者可能會疑惑,直接對低分辨率圖像進行“拉伸”不就可以了嗎?答案是可以的,但是效果並不好。傳統的“拉伸”型算法主要採用近鄰搜索等方式,即對低分辨率圖像中的每個像素採用近鄰查找或近鄰插值的方式進行重建,這種手工設定的方式只考慮了局部並不能滿足每個像素的特殊情況,難以恢復出低分辨率圖像原本的細節信息。因此,一系列有效的超分辨率重建算法開始陸續被研究學者提出,重建能力不斷加強,直至今日,依託深度學習技術,圖像的超分辨率重建已經取得了非凡的成績,在效果上愈發真實和清晰。

2. 應用領域

1955年,Toraldo di Francia在光學成像領域首次明確定義了超分辨率這一概念,主要是指利用光學相關的知識,恢復出衍射極限以外的數據信息的過程。1964年左右,Harris和Goodman則首次提出了圖像超分辨率這一概念,主要是指利用外推頻譜的方法合成出細節信息更豐富的單幀圖像的過程。1984 年,在前人的基礎上,Tsai和 Huang 等首次提出使用多幀低分辨率圖像重建出高分辨率圖像的方法後, 超分辨率重建技術開始受到了學術界和工業界廣泛的關注和研究。

圖像超分辨率重建技術在多個領域都有着廣泛的應用範圍和研究意義。主要包括:

(1) 圖像壓縮領域

在視頻會議等實時性要求較高的場合,可以在傳輸前預先對圖片進行壓縮,等待傳輸完畢,再由接收端解碼後通過超分辨率重建技術復原出原始圖像序列,極大減少存儲所需的空間及傳輸所需的帶寬。

(2) 醫學成像領域

對醫學圖像進行超分辨率重建,可以在不增加高分辨率成像技術成本的基礎上,降低對成像環境的要求,通過復原出的清晰醫學影像,實現對病變細胞的精準探測,有助於醫生對患者病情做出更好的診斷。

(3) 遙感成像領域

高分辨率遙感衛星的研製具有耗時長、價格高、流程複雜等特點,由此研究者將圖像超分辨率重建技術引入了該領域,試圖解決高分辨率的遙感成像難以獲取這一挑戰,使得能夠在不改變探測系統本身的前提下提高觀測圖像的分辨率。

(4) 公共安防領域

公共場合的監控設備採集到的視頻往往受到天氣、距離等因素的影響,存在圖像模糊、分辨率低等問題。通過對採集到的視頻進行超分辨率重建,可以爲辦案人員恢復出車牌號碼、清晰人臉等重要信息,爲案件偵破提供必要線索。

(5) 視頻感知領域

通過圖像超分辨率重建技術,可以起到增強視頻畫質、改善視頻的質量,提升用戶的視覺體驗的作用。

3. 研究進展

按照時間和效果進行分類,可以將超分辨率重建算法分爲傳統算法和深度學習算法兩類。

3.1 傳統超分辨率重建算法

傳統的超分辨率重建算法主要依靠基本的數字圖像處理技術進行重建,常見的有如下幾類:

(1) 基於插值的超分辨率重建

基於插值的方法將圖像上每個像素都看做是圖像平面上的一個點,那麼對超分辨率圖像的估計可以看做是利用已知的像素信息爲平面上未知的像素信息進行擬合的過程,這通常由一個預定義的變換函數或者插值核來完成。基於插值的方法計算簡單、易於理解,但是也存在着一些明顯的缺陷。

首先,它假設像素灰度值的變化是一個連續的、平滑的過程,但實際上這種假設並不完全成立。其次,在重建過程中,僅根據一個事先定義的轉換函數來計算超分辨率圖像,不考慮圖像的降質退化模型,往往會導致復原出的圖像出現模糊、鋸齒等現象。常見的基於插值的方法包括最近鄰插值法、雙線性插值法和雙立方插值法等。

(2) 基於退化模型的超分辨率重建

此類方法從圖像的降質退化模型出發,假定高分辨率圖像是經過了適當的運動變換、模糊及噪聲纔得到低分辨率圖像。這種方法通過提取低分辨率圖像中的關鍵信息,並結合對未知的超分辨率圖像的先驗知識來約束超分辨率圖像的生成。常見的方法包括迭代反投影法、凸集投影法和最大後驗概率法等。

(3) 基於學習的超分辨率重建

基於學習的方法則是利用大量的訓練數據,從中學習低分辨率圖像和高分辨率圖像之間某種對應關係,然後根據學習到的映射關係來預測低分辨率圖像所對應的高分辨率圖像,從而實現圖像的超分辨率重建過程。常見的基於學習的方法包括流形學習、稀疏編碼方法。

3.2 基於深度學習的超分辨率重建算法

機器學習是人工智能的一個重要分支,而深度學習則是機器學習中最主要的一個算法,其旨在通過多層非線性變換,提取數據的高層抽象特徵,學習數據潛在的分佈規律,從而獲取對新數據做出合理的判斷或者預測的能力。隨着人工智能和計算機硬件的不斷髮展,Hinton等人在2006年提出了深度學習這一概念,其旨在利用多層非線性變換提取數據的高層抽象特徵。憑藉着強大的擬合能力,深度學習開始在各個領域嶄露頭角,特別是在圖像與視覺領域,卷積神經網絡大放異,這也使得越來越多的研究者開始嘗試將深度學習引入到超分辨率重建領域。

2014年,Dong等人首次將深度學習應用到圖像超分辨率重建領域,他們使用一個三層的卷積神經網絡學習低分辨率圖像與高分辨率圖像之間映射關係,自此,在超分辨率重建率領域掀起了深度學習的浪潮,他們的設計的網絡模型命名爲SRCNN(Super-Resolution Convolutional Neural Network)。

SRCNN採用了插值的方式先將低分辨率圖像進行放大,再通過模型進行復原。Shi等人則認爲這種預先採用近鄰插值的方式本身已經影響了性能,如果從源頭出發,應該從樣本中去學習如何進行放大,他們基於這個原理提出了ESPCN (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network)算法。該算法在將低分辨率圖像送入神經網絡之前,無需對給定的低分辨率圖像進行一個上採樣過程,而是引入一個亞像素卷積層(Sub-pixel convolution layer),來間接實現圖像的放大過程。這種做法極大降低了SRCNN的計算量,提高了重建效率。

這裏需要注意到,不管是SRCNN還是ESPCN,它們均使用了MSE作爲目標函數來訓練模型。2017年,Christian Ledig等人從照片感知角度出發,通過對抗網絡來進行超分重建(論文題目:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network)。他們認爲,大部分深度學習超分算法採用的MSE損失函數會導致重建的圖像過於平滑,缺乏感官上的照片真實感。他們改用生成對抗網絡(Generative Adversarial Networks, GAN)來進行重建,並且定義了新的感知目標函數,算法被命名爲SRGAN,由一個生成器和一個判別器組成。生成器負責合成高分辨率圖像,判別器用於判斷給定的圖像是來自生成器還是真實樣本。通過一個博弈的對抗過程,使得生成器能夠將給定的低分辨率圖像重建爲高分辨率圖像。在SRGAN這篇論文中,作者同時提出了一個對比算法,名爲SRResNet。SRResNet依然採用了MSE作爲最終的損失函數,與以往不同的是,SRResNet採用了足夠深的殘差卷積網絡模型,相比於其它的殘差學習重建算法,SRResNet本身也能夠取得較好的效果。

由於SRGAN這篇論文同時提出了兩種當前主流模式的深度學習超分重建算法,因此,接下來將以SRGAN這篇論文爲主線,依次講解SRResNet和SRGAN算法實現原理,並採用Pytorch深度學習框架完成上述兩個算法的復現。

二.  SRResNet算法原理和Pytorch實現

1. 超分重建基本處理流程

最早的採用深度學習進行超分重建的算法是SRCNN算法,其原理很簡單,對於輸入的一張低分辨率圖像,SRCNN首先使用雙立方插值將其放大至目標尺寸,然後利用一個三層的卷積神經網絡去擬合低分辨率圖像與高分辨率圖像之間的非線性映射,最後將網絡輸出的結果作爲重建後的高分辨率圖像。儘管原理簡單,但是依託深度學習模型以及大樣本數據的學習,在性能上超過了當時一衆傳統的圖像處理算法,開啓了深度學習在超分辨率領域的研究征程。SRCNN的網絡結構如圖2所示。

                                                                                                             圖2 SRCNN的網絡結構

其中f_1f_2f_3分別表示1、2、3層卷積對應的核大小。

SRCNN作爲早期開創性的研究論文,也爲後面的工作奠定了處理超分問題的基本流程:

(1) 尋找大量真實場景下圖像樣本;

(2) 對每張圖像進行下采樣處理降低圖像分辨率,一般有2倍下采樣、3倍下采樣、4倍下采樣等。如果是2倍下采樣,則圖像長寬均變成原來的1/2.。下采樣前的圖像作爲高分辨率圖像H,下采樣後的圖像作爲低分辨率圖像L,L和H構成一個有效的圖像對用於後期模型訓練;

(3) 訓練模型時,對低分辨率圖像L進行放大還原爲高分辨率圖像SR,然後與原始的高分辨率圖像H進行比較,其差異用來調整模型的參數,通過迭代訓練,使得差異最小。實際情況下,研究學者提出了多種損失函數用來定義這種差異,不同的定義方式也會直接影響最終的重建效果;

(4) 訓練完的模型可以用來對新的低分辨率圖像進行重建,得到高分辨率圖像。

從實際操作上來看,整個超分重建分爲兩步:圖像放大和修復。所謂放大就是採用某種方式(SRCNN採用了插值上採樣)將圖像放大到指定倍數,然後再根據圖像修復原理,將放大後的圖像映射爲目標圖像。超分辨率重建不僅能夠放大圖像尺寸,在某種意義上具備了圖像修復的作用,可以在一定程度上削弱圖像中的噪聲、模糊等。因此,超分辨率重建的很多算法也被學者遷移到圖像修復領域中,完成一些諸如jpep壓縮去燥、去模糊等任務。

                                                                                           圖3 簡化版超分重建處理流程

簡化版的超分重建處理流程如圖3所示,當然,圖像放大和修復兩個步驟的順序可以任意互換。

2. 構建深度網絡模型提高超分重建性能

SRCNN只採用了3個卷積層來實現超分重建,有文獻指出如果採用更深的網絡結構模型,那麼可以重建出更高質量的圖像,因爲更深的網絡模型可以抽取出更高級的圖像特徵,這種深層模型對圖像可以更好的進行表達。在SRCNN之後,有不少研究人員嘗試加深網絡結構以期取得更佳的重建性能,但是越深的模型越不能很好的收斂,無法得到期望的結果。部分研究學者通過遷移學習來逐步的增加模型深度,但這種方式加深程度有限。因此,亟需一種有效的模型,使得構建深層網絡模型變得容易並且有效。這個問題直到2015年由何凱明團隊提出ResNet網絡才得以有效解決。

ResNet中文名字叫作深度殘差網絡,主要作用是圖像分類。現在在圖像分割、目標檢測等領域都有很廣泛的運用。ResNet在傳統卷積神經網絡中加入了殘差學習(residual learning),解決了深層網絡中梯度彌散和精度下降(訓練集)的問題,使網絡能夠越來越深,既保證了精度,又控制了速度。

ResNet可以直觀的來理解其背後的意義。以往的神經網絡模型每一層學習的是一個 y = f(x) 的映射,可以想象的到,隨着層數不斷加深,每個函數映射出來的y誤差逐漸累計,誤差越來越大,梯度在反向傳播的過程中越來越發散。這時候,如果改變一下每層的映射關係,改爲 y = f(x) + x,也就是在每層的結束加上原始輸入,此時輸入是x,輸出是f(x)+x,那麼自然的f(x)趨向於0,或者說f(x)是一個相對較小的值,這樣,即便層數不斷加大,這個誤差f(x)依然控制在一個較小值,整個模型訓練時不容易發散。

                                                                                         圖4 殘差網絡原理圖

上圖爲殘差網絡的原理圖,可以看到一根線直接跨越兩層網絡(跳鏈),將原始數據x帶入到了輸出中,此時F(x)預測的是一個差值。有了殘差學習這種強大的網絡結構,就可以按照SRCNN的思路構建用於超分重建的深度神經網絡。SRResNet算法主幹部分就採用了這種網絡結構,如下圖所示:

                                                                             圖5 超分重建深度殘差模塊

上述模型採用了多個深度殘差模塊進行圖像的特徵抽取,多次運用跳鏈技術將輸入連接到網絡輸出,這種結構能夠保證整個網絡的穩定性。由於採用了深度模型,相比淺層模型能夠更有效的挖掘圖像特徵,在性能上可以超越淺層模型算法(SRResNet使用了16個殘差模塊)。注意到,上述模型每層僅僅改變了圖像的通道數,並沒有改變圖像的尺寸大小,從這個意義上來說這個網絡可以認爲是前面提到的修復模型。下面會介紹如何在這個模型基礎上再增加一個子模塊用來放大圖像,從而構建一個完整的超分重建模型。

3.  基於子像素卷積放大圖像尺寸

子像素卷積(Sub-pixel convolution)是一種巧妙的圖像及特徵圖放大方法,又叫做pixel shuffle(像素清洗)。在深度學習超分辨率重建中,常見的擴尺度方法有直接上採樣,雙線性插值,反捲積等等。ESPCN算法中提出了一種超分辨率擴尺度方法,即爲子像素卷積方法,該方法後續也被應用在了SRResNet和SRGAN算法中。因此,這裏需要先介紹子像素卷積的原理及實現方式。

採用CNN對特徵圖進行放大一般會採用deconvolution等方法,這種方法通常會帶入過多人工因素,而子像素卷積會大大降低這個風險。因爲子像素卷積放大使用的參數是需要學習的,相比那些手工設定的方式,這種通過樣本學習的方式其放大性能更加準確。

具體實現原理如下圖所示:

                                                                                             圖6 子像素卷積示意圖

上圖很直觀得表達了子像素卷積的流程。假設,如果想對原圖放大3倍,那麼需要生成出3^2=9個同等大小的特徵圖,也就是通道數擴充了9倍(這個通過普通的卷積操作即可實現)。然後將九個同等大小的特徵圖拼成一個放大3倍的大圖,這就是子像素卷積操作了。

實現時先將原始特徵圖通過卷積擴展其通道數,如果是想放大4倍,那麼就需要將通道數擴展爲原來的16倍。特徵圖做完卷積後再按照特定的格式進行排列,即可得到一張大圖,這就是所謂的像素清洗。通過像素清洗,特徵的通道數重新恢復爲原來輸入時的大小,但是每個特徵圖的尺寸變大了。這裏注意到每個像素的擴展方式由對應的卷積來決定,此時卷積的參數是需要學習的,因此,相比於手工設計的放大方式,這種基於學習的放大方式能夠更好的去擬合像素之間的關係。

SRResNet模型也利用子像素卷積來放大圖像,具體的,在圖5所示模型後面添加兩個子像素卷積模塊,每個子像素卷積模塊使得輸入圖像放大2倍,因此這個模型最終可以將圖像放大4倍,如下圖所示:

                                                                                     圖7 SRResNet子像素卷積模塊

4.  SRResNet結構剖析

SRResNet使用深度殘差網絡來構建超分重建模型,主要包含兩部分:深度殘差模型、子像素卷積模型。深度殘差模型用來進行高效的特徵提取,可以在一定程度上削弱圖像噪點。子像素卷積模型主要用來放大圖像尺寸。完整的SRResNet網絡結果如下圖所示:

                                                                                          圖 8 SRResNet網絡結構

上圖中,k表示卷積核大小,n表示輸出通道數,s表示步長。除了深度殘差模塊和子像素卷積模塊以外,在整個模型輸入和輸出部分均添加了一個卷積模塊用於數據調整和增強。

需要注意的是,SRResNet模型使用MSE作爲目標函數,也就是通過模型還原出來的高分辨率圖像與原始高分辨率圖像的均方誤差,公式如下:

                                                                           L=\frac{1}{n}\sum_{i=1}^{n}\|F(X_i;\Theta))-Y_i\|^2

MSE也是目前大部分超分重建算法採用的目標函數。後面我們會看到,使用該目標函數重建的超分圖像並不能很好的符合人眼主觀感受,SRGAN算法正是基於此進行的改進。

5. Pytorch實現

本節將從源碼出發,完成SRResNet算法的建模、訓練和推理。本文基於深度學習框架Pytorch來完成所有的編碼工作,讀者在閱讀本文代碼前需要熟悉Pytorch基本操作命令。

所有代碼和數據可以從百度雲上進行下載https://pan.baidu.com/s/1yUCK8JMmMRjDgtwwUt7z2w

提取碼:jkpy

該工程比較大,主要是包含了用於訓練的COCO2014數據集。提供這樣一個完整的工程包是爲了方便讀者只需要下載和解壓就可以直接運行,而不需要再去額外的尋找數據集和測試集。代碼裏也提供了已經訓練好的.pth模型文件。

5.1 運行環境

(1) 基本配置

本文使用Python語言進行代碼編寫,Python版本爲3.6

採用Pytorch深度學習框架進行算法建模,對應的版本爲torch1.4.0。Pytorch的完整安裝教程請參考另一篇博客。Pytorch的兼容性較好,1.0版本以後代碼都可以正常運行。

操作系統爲Windows 10,使用2塊 GTX 1080TI顯卡進行加速運算。

IDE爲VS Code。

(2) 安裝scikit-image

按照下述命令進行安裝,該包主要用來提供PSNR和SSIM的計算。PSNR和SSIM是超分重建中經常會使用的兩個評價指標。

pip install scikit-image==0.16.2

(3) 可視化結果

爲了能夠在模型的訓練過程中方便的查看運行結果,推薦使用tensorboard來實現。由於tensorboard是tensorflow推出來的,因此首先需要安裝tensorflow,但是此處沒必要再安裝GPU版的tensorflow,只需要直接安裝CPU版本的即可,安裝命令如下:

pip install tensorflow

本文安裝的tensorflow版本爲2.1.0。在安裝tensorflow的過程中會自動安裝tensorboard。

(4) 數據集

本文使用COCO2014數據集進行訓練,訓練時聯合使用train2014(82783張圖片)和val2014(40504張圖片)兩部分數據,共有123285張圖像。測試時使用Set5、Set14和BSD100數據集分別進行測試。

5.2 訓練

(1)代碼結構組織

爲了方便讀者閱讀、運行和修改代碼,本文采用比較簡單的代碼組織方式。完整結構如下圖所示:

                                                                                                圖9 代碼組織結構

項目根目錄下有8個.py文件和2個文件夾,下面對各個文件和文件夾進行簡單說明。

  • create_data_lists.py:生成數據列表,檢查數據集中的圖像文件尺寸,並將符合的圖像文件名寫入JSON文件列表供後續Pytorch調用;
  • datasets.py:用於構建數據集加載器,主要沿用Pytorch標準數據加載器格式進行封裝;
  • models.py:模型結構文件,存儲各個模型的結構定義;
  • utils.py:工具函數文件,所有項目中涉及到的一些自定義函數均放置在該文件中;
  • train_srresnet.py:用於訓練SRResNet算法;
  • train_srgan.py:用於訓練SRGAN算法;
  • eval.py:用於模型評估,主要以計算測試集的PSNR和SSIM爲主;
  • test.py:用於單張樣本測試,運用訓練好的模型爲單張圖像進行超分重建;
  • data:用於存放訓練和測試數據集以及文件列表;
  • results:用於存放運行結果,包括訓練好的模型以及單張樣本測試結果;

讀者可以下載本文代碼和數據集進行查看和運行,整個代碼運行順序如下:

  • 運行create_data_lists.py文件用於爲數據集生成文件列表;
  • 運行train_srresnet.py進行SRResNet算法訓練,訓練結束後在results文件夾中會生成checkpoint_srresnet.pth模型文件;
  • 運行eval.py文件對測試集進行評估,計算每個測試集的平均PSNR、SSIM值;
  • 運行test.py文件對results文件夾下名爲test.jpg的圖像進行超分還原,還原結果存儲在results文件夾下面;
  • 運行train_srgan.py文件進行SRGAN算法訓練,訓練結束後在results文件夾中會生成checkpoint_srgan.pth模型文件;
  • 修改並運行eval.py文件對測試集進行評估,計算每個測試集的平均PSNR、SSIM值;
  • 修改並運行test.py文件對results文件夾下名爲test.jpg的圖像進行超分還原,還原結果存儲在results文件夾下面;

(2)生成數據列表

在訓練前需要先準備好數據集,按照特定的格式生成文件列表供Pytorch的數據加載器torch.utils.data.DataLoader對圖像進行高效率並行加載。只有準確生成了數據集列表文件才能進行下面的訓練。

create_data_lists.py文件內容如下:

from utils import create_data_lists

if __name__ == '__main__':
    create_data_lists(train_folders=['./data/COCO2014/train2014',
                                     './data/COCO2014/val2014'],
                      test_folders=['./data/BSD100',
                                    './data/Set5',
                                    './data/Set14'],
                      min_size=100,
                      output_folder='./data/')

首先從utils中導入create_data_lists函數,該函數用於執行具體的JSON文件創建。在主函數部分設置好訓練集train_folders和測試集test_folders文件夾路徑,參數min_size=100用於檢查訓練集和測試集中的每張圖像分辨率,無論是圖像寬度還是高度,如果小於min_size則該圖像路徑不寫入JSON文件列表中。output_folder用於指明最後的JSON文件列表存放路徑。

create_data_lists實現方式如下:

from PIL import Image
import os
import json

def create_data_lists(train_folders, test_folders, min_size, output_folder):
    """
    創建訓練集和測試集列表文件.
        參數 train_folders: 訓練文件夾集合; 各文件夾中的圖像將被合併到一個圖片列表文件裏面
        參數 test_folders: 測試文件夾集合; 每個文件夾將形成一個圖片列表文件
        參數 min_size: 圖像寬、高的最小容忍值
        參數 output_folder: 最終生成的文件列表,json格式
    """
    print("\n正在創建文件列表... 請耐心等待.\n")
    train_images = list()
    for d in train_folders:
        for i in os.listdir(d):
            img_path = os.path.join(d, i)
            img = Image.open(img_path, mode='r')
            if img.width >= min_size and img.height >= min_size:
                train_images.append(img_path)
    print("訓練集中共有 %d 張圖像\n" % len(train_images))
    with open(os.path.join(output_folder, 'train_images.json'), 'w') as j:
        json.dump(train_images, j)

    for d in test_folders:
        test_images = list()
        test_name = d.split("/")[-1]
        for i in os.listdir(d):
            img_path = os.path.join(d, i)
            img = Image.open(img_path, mode='r')
            if img.width >= min_size and img.height >= min_size:
                test_images.append(img_path)
        print("在測試集 %s 中共有 %d 張圖像\n" %
              (test_name, len(test_images)))
        with open(os.path.join(output_folder, test_name + '_test_images.json'),'w') as j:
            json.dump(test_images, j)

    print("生成完畢。訓練集和測試集文件列表已保存在 %s 下\n" % output_folder)

運行後,在data文件夾下會產生下列文件列表(每個文件列表均存儲了用於訓練和測試的圖像路徑):

(3)訓練SRResNet

train_srresnet.py文件內容如下:

import torch.backends.cudnn as cudnn
import torch
from torch import nn
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from models import SRResNet
from datasets import SRDataset
from utils import *


# 數據集參數
data_folder = './data/'          # 數據存放路徑
crop_size = 96      # 高分辨率圖像裁剪尺寸
scaling_factor = 4  # 放大比例

# 模型參數
large_kernel_size = 9   # 第一層卷積和最後一層卷積的核大小
small_kernel_size = 3   # 中間層卷積的核大小
n_channels = 64         # 中間層通道數
n_blocks = 16           # 殘差模塊數量

# 學習參數
checkpoint = None   # 預訓練模型路徑,如果不存在則爲None
batch_size = 400    # 批大小
start_epoch = 1     # 輪數起始位置
epochs = 130        # 迭代輪數
workers = 4         # 工作線程數
lr = 1e-4           # 學習率

# 設備參數
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ngpu = 2           # 用來運行的gpu數量

cudnn.benchmark = True # 對卷積進行加速

writer = SummaryWriter() # 實時監控     使用命令 tensorboard --logdir runs  進行查看

def main():
    """
    訓練.
    """
    global checkpoint,start_epoch,writer

    # 初始化
    model = SRResNet(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    # 初始化優化器
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),lr=lr)

    # 遷移至默認設備進行訓練
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # 加載預訓練模型
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # 定製化的dataloaders
    train_dataset = SRDataset(data_folder,split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True) 

    # 開始逐輪訓練
    for epoch in range(start_epoch, epochs+1):

        model.train()  # 訓練模式:允許使用批樣本歸一化

        loss_epoch = AverageMeter()  # 統計損失函數

        n_iter = len(train_loader)

        # 按批處理
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):

            # 數據移至默認設備進行訓練
            lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed 格式
            hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96),  [-1, 1]格式

            # 前向傳播
            sr_imgs = model(lr_imgs)

            # 計算損失
            loss = criterion(sr_imgs, hr_imgs)  

            # 後向傳播
            optimizer.zero_grad()
            loss.backward()

            # 更新模型
            optimizer.step()

            # 記錄損失值
            loss_epoch.update(loss.item(), lr_imgs.size(0))

            # 監控圖像變化
            if i==(n_iter-2):
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_1', make_grid(lr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_2', make_grid(sr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_3', make_grid(hr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)

            # 打印結果
            print("第 "+str(i)+ " 個batch訓練結束")
 
        # 手動釋放內存              
        del lr_imgs, hr_imgs, sr_imgs

        # 監控損失值變化
        writer.add_scalar('SRResNet/MSE_Loss', loss_epoch.val, epoch)    

        # 保存預訓練模型
        torch.save({
            'epoch': epoch,
            'model': model.module.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'results/checkpoint_srresnet.pth')
    
    # 訓練結束關閉監控
    writer.close()


if __name__ == '__main__':
    main()

上述代碼中已經配上詳細的註釋,讀者可以先Debug逐行運行即可。其中參數batch_size 設置爲400,GPU數量ngpu設置爲2,讀者可以根據自己的機器性能調節設置,對於一般的GPU,batch_size設置爲128可以滿足。
        需要說明的是本文使用了tensorboard來查看訓練結果,在程序運行的過程中可以再開一個終端運行下述命令:

tensorboard --logdir runs

從而打開tensorboard監視器,打開後可以在瀏覽器中訪問:http://localhost:6006/ 來監視訓練結果。

上述代碼實現時首先構建了一個SRResNet模型,該模型的定義在models.py文件中給出:

class SRResNet(nn.Module):
    """
    SRResNet模型
    """
    def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
        """
        :參數 large_kernel_size: 第一層卷積和最後一層卷積核大小
        :參數 small_kernel_size: 中間層卷積核大小
        :參數 n_channels: 中間層通道數
        :參數 n_blocks: 殘差模塊數
        :參數 scaling_factor: 放大比例
        """
        super(SRResNet, self).__init__()

        # 放大比例必須爲 2、 4 或 8
        scaling_factor = int(scaling_factor)
        assert scaling_factor in {2, 4, 8}, "放大比例必須爲 2、 4 或 8!"

        # 第一個卷積塊
        self.conv_block1 = ConvolutionalBlock(in_channels=3, out_channels=n_channels, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='PReLu')

        # 一系列殘差模塊, 每個殘差模塊包含一個跳連接
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(kernel_size=small_kernel_size, n_channels=n_channels) for i in range(n_blocks)])

        # 第二個卷積塊
        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels,
                                              kernel_size=small_kernel_size,
                                              batch_norm=True, activation=None)

        # 放大通過子像素卷積模塊實現, 每個模塊放大兩倍
        n_subpixel_convolution_blocks = int(math.log2(scaling_factor))
        self.subpixel_convolutional_blocks = nn.Sequential(
            *[SubPixelConvolutionalBlock(kernel_size=small_kernel_size, n_channels=n_channels, scaling_factor=2) for i
              in range(n_subpixel_convolution_blocks)])

        # 最後一個卷積模塊
        self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=3, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='Tanh')

    def forward(self, lr_imgs):
        """
        前向傳播.

        :參數 lr_imgs: 低分辨率輸入圖像集, 張量表示,大小爲 (N, 3, w, h)
        :返回: 高分辨率輸出圖像集, 張量表示, 大小爲 (N, 3, w * scaling factor, h * scaling factor)
        """
        output = self.conv_block1(lr_imgs)  # (16, 3, 24, 24)
        residual = output  # (16, 64, 24, 24)
        output = self.residual_blocks(output)  # (16, 64, 24, 24)
        output = self.conv_block2(output)  # (16, 64, 24, 24)
        output = output + residual  # (16, 64, 24, 24)
        output = self.subpixel_convolutional_blocks(output)  # (16, 64, 24 * 4, 24 * 4)
        sr_imgs = self.conv_block3(output)  # (16, 3, 24 * 4, 24 * 4)

        return sr_imgs

整個模型完全參照SRResNet的實現方式,組成方式爲:1個卷積模塊+16個殘差模塊+1個卷積模塊+2個子像素卷積模塊+1個卷積模塊。

數據的加載通過自定義的SRDataset來實現,其定義在datasets.py文件中給出:

import torch
from torch.utils.data import Dataset
import json
import os
from PIL import Image
from utils import ImageTransforms


class SRDataset(Dataset):
    """
    數據集加載器
    """

    def __init__(self, data_folder, split, crop_size, scaling_factor, lr_img_type, hr_img_type, test_data_name=None):
        """
        :參數 data_folder: # Json數據文件所在文件夾路徑
        :參數 split: 'train' 或者 'test'
        :參數 crop_size: 高分辨率圖像裁剪尺寸  (實際訓練時不會用原圖進行放大,而是截取原圖的一個子塊進行放大)
        :參數 scaling_factor: 放大比例
        :參數 lr_img_type: 低分辨率圖像預處理方式
        :參數 hr_img_type: 高分辨率圖像預處理方式
        :參數 test_data_name: 如果是評估階段,則需要給出具體的待評估數據集名稱,例如 "Set14"
        """

        self.data_folder = data_folder
        self.split = split.lower()
        self.crop_size = int(crop_size)
        self.scaling_factor = int(scaling_factor)
        self.lr_img_type = lr_img_type
        self.hr_img_type = hr_img_type
        self.test_data_name = test_data_name

        assert self.split in {'train', 'test'}
        if self.split == 'test' and self.test_data_name is None:
            raise ValueError("請提供測試數據集名稱!")
        assert lr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'}
        assert hr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'}

        # 如果是訓練,則所有圖像必須保持固定的分辨率以此保證能夠整除放大比例
        # 如果是測試,則不需要對圖像的長寬作限定
        if self.split == 'train':
            assert self.crop_size % self.scaling_factor == 0, "裁剪尺寸不能被放大比例整除!"

        # 讀取圖像路徑
        if self.split == 'train':
            with open(os.path.join(data_folder, 'train_images.json'), 'r') as j:
                self.images = json.load(j)
        else:
            with open(os.path.join(data_folder, self.test_data_name + '_test_images.json'), 'r') as j:
                self.images = json.load(j)

        # 數據處理方式
        self.transform = ImageTransforms(split=self.split,
                                         crop_size=self.crop_size,
                                         scaling_factor=self.scaling_factor,
                                         lr_img_type=self.lr_img_type,
                                         hr_img_type=self.hr_img_type)

    def __getitem__(self, i):
        """
        爲了使用PyTorch的DataLoader,必須提供該方法.

        :參數 i: 圖像檢索號
        :返回: 返回第i個低分辨率和高分辨率的圖像對
        """
        # 讀取圖像
        img = Image.open(self.images[i], mode='r')
        img = img.convert('RGB')
        if img.width <= 96 or img.height <= 96:
            print(self.images[i], img.width, img.height)
        lr_img, hr_img = self.transform(img)

        return lr_img, hr_img

    def __len__(self):
        """
        爲了使用PyTorch的DataLoader,必須提供該方法.

        :返回: 加載的圖像總數
        """
        return len(self.images)

其中需要注意的是,我們提供了4中圖像數據的變換方式:'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'。對於SRResNet數據集,我們希望輸入的圖像數據經過標準的ImageNet處理,即減去ImageNet均值併除以其方差,對於輸出將其變換至[-1,1]之間。具體的變換由utils.py文件的convert_image函數實現。

圖像加載的處理流程如下:

  • 加載一張圖像,從任意位置處裁剪96x96的子塊,將該子塊作爲原始高分辨率圖像hr_img;
  • 對hr_img進行雙線性下采樣(4倍),得到24x24的子塊,將該子塊作爲初始的低分辨率圖像lr_img;
  • 對lr_img按照ImageNet數據集方式進行預處理,將hr_img轉換至[-1,1];
  • 將lr_img和hr_img作爲一對訓練對返回;

圖像變換的實現方式在utils.py文件中的ImageTransforms類給出:

class ImageTransforms(object):
    """
    圖像變換.
    """

    def __init__(self, split, crop_size, scaling_factor, lr_img_type,
                 hr_img_type):
        """
        :參數 split: 'train' 或 'test'
        :參數 crop_size: 高分辨率圖像裁剪尺寸
        :參數 scaling_factor: 放大比例
        :參數 lr_img_type: 低分辨率圖像預處理方式
        :參數 hr_img_type: 高分辨率圖像預處理方式
        """
        self.split = split.lower()
        self.crop_size = crop_size
        self.scaling_factor = scaling_factor
        self.lr_img_type = lr_img_type
        self.hr_img_type = hr_img_type

        assert self.split in {'train', 'test'}

    def __call__(self, img):
        """
        對圖像進行裁剪和下采樣形成低分辨率圖像
        :參數 img: 由PIL庫讀取的圖像
        :返回: 特定形式的低分辨率和高分辨率圖像
        """

        # 裁剪
        if self.split == 'train':
            # 從原圖中隨機裁剪一個子塊作爲高分辨率圖像
            left = random.randint(1, img.width - self.crop_size)
            top = random.randint(1, img.height - self.crop_size)
            right = left + self.crop_size
            bottom = top + self.crop_size
            hr_img = img.crop((left, top, right, bottom))
        else:
            # 從圖像中儘可能大的裁剪出能被放大比例整除的圖像
            x_remainder = img.width % self.scaling_factor
            y_remainder = img.height % self.scaling_factor
            left = x_remainder // 2
            top = y_remainder // 2
            right = left + (img.width - x_remainder)
            bottom = top + (img.height - y_remainder)
            hr_img = img.crop((left, top, right, bottom))

        # 下采樣(雙三次差值)
        lr_img = hr_img.resize((int(hr_img.width / self.scaling_factor),
                                int(hr_img.height / self.scaling_factor)),
                               Image.BICUBIC)

        # 安全性檢查
        assert hr_img.width == lr_img.width * self.scaling_factor and hr_img.height == lr_img.height * self.scaling_factor

        # 轉換圖像
        lr_img = convert_image(lr_img, source='pil', target=self.lr_img_type)
        hr_img = convert_image(hr_img, source='pil', target=self.hr_img_type)

        return lr_img, hr_img

由於採用的是Pytorch框架,讀者可以方便的通過debug方式逐行運行和調試代碼。整體實現難度並不大,本文不再對此進行贅述。         

(4)訓練結果

訓練共用時5小時19分6秒(2塊GTX 1080Ti顯卡),訓練完成後保存的模型共17.8M。下圖展示了訓練過程中的損失函數變化。可以看到,隨着訓練的進行,損失函數逐漸開始收斂,在結束的時候基本處在收斂平穩點。

                                                                              圖10 訓練時損失函數變化

下圖展示了訓練過程中訓練數據超分重建的效果圖,依次展示epoch=1、60和130時的效果,每張圖像共三行,第一行爲低分辨率圖像,第二行爲當前模型重建出的超分圖像,第三行爲實際的真實原始清晰圖像。可以看到,隨着迭代次數的增加,超分還原的效果越來越好,到了第99個epoch的時候還原出來的圖像已經大幅削弱了塊狀噪點的影響,圖像更加的平滑和清晰。

                                                                                      圖11 epoch=1時超分重建效

                                                                                     圖12 epoch=60時超分重建效果

                                                                                     圖13 epoch=130時超分重建效果

5.3 評估

eval.py文件完整代碼如下:

from utils import *
from torch import nn
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from datasets import SRDataset
from models import SRResNet
import time

# 模型參數
large_kernel_size = 9   # 第一層卷積和最後一層卷積的核大小
small_kernel_size = 3   # 中間層卷積的核大小
n_channels = 64         # 中間層通道數
n_blocks = 16           # 殘差模塊數量
scaling_factor = 4      # 放大比例
ngpu = 2                # GP數量
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':
    
    # 測試集目錄
    data_folder = "./data/"
    test_data_names = ["Set5","Set14", "BSD100"]

    # 預訓練模型
    srresnet_checkpoint = "./results/checkpoint_srresnet.pth"

    # 加載模型SRResNet
    checkpoint = torch.load(srresnet_checkpoint)
    srresnet = SRResNet(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    srresnet = srresnet.to(device)
    srresnet.load_state_dict(checkpoint['model'])

    # 多GPU測試
    if torch.cuda.is_available() and ngpu > 1:
        srresnet = nn.DataParallel(srresnet, device_ids=list(range(ngpu)))
   
    srresnet.eval()
    model = srresnet

    for test_data_name in test_data_names:
        print("\n數據集 %s:\n" % test_data_name)

        # 定製化數據加載器
        test_dataset = SRDataset(data_folder,
                                split='test',
                                crop_size=0,
                                scaling_factor=4,
                                lr_img_type='imagenet-norm',
                                hr_img_type='[-1, 1]',
                                test_data_name=test_data_name)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1,
                                                pin_memory=True)

        # 記錄每個樣本 PSNR 和 SSIM值
        PSNRs = AverageMeter()
        SSIMs = AverageMeter()

        # 記錄測試時間
        start = time.time()

        with torch.no_grad():
            # 逐批樣本進行推理計算
            for i, (lr_imgs, hr_imgs) in enumerate(test_loader):
                
                # 數據移至默認設備
                lr_imgs = lr_imgs.to(device)  # (batch_size (1), 3, w / 4, h / 4), imagenet-normed
                hr_imgs = hr_imgs.to(device)  # (batch_size (1), 3, w, h), in [-1, 1]

                # 前向傳播.
                sr_imgs = model(lr_imgs)  # (1, 3, w, h), in [-1, 1]                

                # 計算 PSNR 和 SSIM
                sr_imgs_y = convert_image(sr_imgs, source='[-1, 1]', target='y-channel').squeeze(
                    0)  # (w, h), in y-channel
                hr_imgs_y = convert_image(hr_imgs, source='[-1, 1]', target='y-channel').squeeze(0)  # (w, h), in y-channel
                psnr = peak_signal_noise_ratio(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(),
                                            data_range=255.)
                ssim = structural_similarity(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(),
                                            data_range=255.)
                PSNRs.update(psnr, lr_imgs.size(0))
                SSIMs.update(ssim, lr_imgs.size(0))


        # 輸出平均PSNR和SSIM
        print('PSNR  {psnrs.avg:.3f}'.format(psnrs=PSNRs))
        print('SSIM  {ssims.avg:.3f}'.format(ssims=SSIMs))
        print('平均單張樣本用時  {:.3f} 秒'.format((time.time()-start)/len(test_dataset)))

    print("\n")

最終在三個數據集上的測試結果如下表所示:

  Set5 Set14 BSD100
PSNR 31.866 28.504 27.498
SSIM 0.900 0.797 0.754
單張圖片平均用時(毫秒) 886 304 81

上表中結果與論文中的值基本一致。由於初始化等隨機因素的影響,讀者在復現的時候並不一定與上述值完全一致,較爲接近即可。上述測試值在性能上已經超越了很多算法,例如DRCNN、ESPCN等。這個主要歸功於深度殘差網絡的作用,我們採用了16個殘差模塊進行學習,其對於圖像的特徵表示能力更加顯著。讀者可以自行嘗試進一步再加深網絡模塊數查看效果,本文不再贅述。

5.4 測試

test.py文件完整代碼如下:

from utils import *
from torch import nn
from models import SRResNet
import time
from PIL import Image

# 測試圖像
imgPath = './results/test.jpg'

# 模型參數
large_kernel_size = 9   # 第一層卷積和最後一層卷積的核大小
small_kernel_size = 3   # 中間層卷積的核大小
n_channels = 64         # 中間層通道數
n_blocks = 16           # 殘差模塊數量
scaling_factor = 4      # 放大比例
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':

    # 預訓練模型
    #srgan_checkpoint = "./results/checkpoint_srgan.pth"
    srresnet_checkpoint = "./results/checkpoint_srresnet.pth"

    # 加載模型SRResNet 或 SRGAN
    checkpoint = torch.load(srresnet_checkpoint)
    srresnet = SRResNet(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    srresnet = srresnet.to(device)
    srresnet.load_state_dict(checkpoint['model'])
   
    srresnet.eval()
    model = srresnet
    # srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device)
    # srgan_generator.eval()
    # model = srgan_generator

    # 加載圖像
    img = Image.open(imgPath, mode='r')
    img = img.convert('RGB')

    # 雙線性上採樣
    Bicubic_img = img.resize((int(img.width * scaling_factor),int(img.height * scaling_factor)),Image.BICUBIC)
    Bicubic_img.save('./results/test_bicubic.jpg')

    # 圖像預處理
    lr_img = convert_image(img, source='pil', target='imagenet-norm')
    lr_img.unsqueeze_(0)

    # 記錄時間
    start = time.time()

    # 轉移數據至設備
    lr_img = lr_img.to(device)  # (1, 3, w, h ), imagenet-normed

    # 模型推理
    with torch.no_grad():
        sr_img = model(lr_img).squeeze(0).cpu().detach()  # (1, 3, w*scale, h*scale), in [-1, 1]   
        sr_img = convert_image(sr_img, source='[-1, 1]', target='pil')
        sr_img.save('./results/test_srres.jpg')

    print('用時  {:.3f} 秒'.format(time.time()-start))

本文從網上選取一張低分辨率的證件照片進行實驗測試。原圖像大小爲130x93,然後分別對其進行雙線性上採樣以及超分重建,圖像放大4倍,變爲520x372。對比效果如下所示:

   

                                 圖14 低分辨率證件照測試效果圖,從左到右依次爲:原圖、bicubic上採樣、超分重建

三.  SRGAN算法原理和Pytorch實現

SRResNet算法是一個單模型算法,從圖像輸入到圖像輸出中間通過各個卷積模塊的操作完成,整個結構比較清晰。但是SRResNet也有不可避免的缺陷,就是它採用了MSE作爲最終的目標函數,而這個MSE是直接通過衡量模型輸出和真值的像素差異來計算的,SRGAN算法指出,這種目標函數會使得超分重建出的圖像過於平滑,儘管PSNR和SSIM值會比較高,但是重建出來的圖像並不能夠很好的符合人眼主觀感受,丟失了細節紋理信息。下面給出一張圖來說明SRResNet算法和SRGAN算法超分重建效果的不同之處:

                        圖15 超分重建效果對比,從左至右分別爲:雙線性插值、SRResNet、SRGAN、真值

從圖上可以看到,原圖因爲分辨率較低,產生了模糊並且丟失了大量的細節信息,雙線性插值無法有效的去模糊,而SRResNet算法儘管能夠一定程度上去除模糊,但是其紋理細節不清晰。最後會發現,SRGAN算法不僅去除了模糊,而且還逼真的重建出了水面上的紋理細節,使得重建的圖片視覺上與真值圖非常吻合。

那怎麼讓模型在紋理細節丟失的情況下“無中生有”的重建出這些信息呢?答案就是生成對抗網絡(Generative Adversarial
Network, GAN)。

1. 生成對抗網絡(GAN)

 GAN的主要靈感來源於博弈論中博弈的思想,應用到深度學習上來說,就是構造兩個深度學習模型:生成網絡G(Generator)和判別網絡D(Discriminator),然後兩個模型不斷博弈,進而使G生成逼真的圖像,而D具有非常強的判斷圖像真僞的能力。生成網絡和判別網絡的主要功能是:

  • G是一個生成式的網絡,它通過某種特定的網絡結構以及目標函數來生成圖像;
  • D是一個判別網絡,判別一張圖片是不是“真實的”,即判斷輸入的照片是不是由G生成;

G的作用就是儘可能的生成逼真的圖像來迷惑D,使得D判斷失敗;而D的作用就是儘可能的挖掘G的破綻,來判斷圖像到底是不是由G生成的“假冒僞劣”。整個過程就好比兩個新手下棋博弈,隨着對弈盤數的增加,一個迷惑手段越來越高明,而另一個甄別本領也越來越強大,最後,兩個新手都變成了高手。這個時候再讓G去和其它的人下棋,可以想到G迷惑的本領已經超越了一衆普通棋手。

以上就是GAN算法的原理。運用在圖像領域,例如風格遷移,超分重建,圖像補全,去噪等,運用GAN可以避免損失函數設計的困難,不管三七二十一,只要有一個基準,直接加上判別器,剩下的就交給對抗訓練。相比其他所有模型, GAN可以產生更加清晰,真實的樣本。

2. 感知損失

爲了防止重建圖像過度平滑,SRGAN重新定義了損失函數,並將其命名爲感知損失(Perceptual loss)。感知損失有兩部分構成:

感知損失=內容損失+對抗損失

對抗損失就是重建出來的圖片被判別器正確判斷的損失,這部分內容跟一般的GAN定義相同。SRGAN的一大創新點就是提出了內容損失,SRGAN希望讓整個網絡在學習的過程中更加關注重建圖片和原始圖片的語義特徵差異,而不是逐個像素之間的顏色亮度差異。以往我們在計算超分重建圖像和原始高清圖像差異的時候是直接在像素圖像上進行比較的,用的MSE準則。SRGAN算法提出者認爲這種方式只會過度的讓模型去學習這些像素差異,而忽略了重建圖像的固有特徵。實際的差異計算應該在圖像的固有特徵上計算。但是這種固有特徵怎麼表示呢?其實很簡單,已經有很多模型專門提出來提取圖像固有特徵然後進行分類等任務。我們只需要把這些模型中的特徵提取模塊截取出來,然後去計算重建圖像和原始圖像的特徵,這些特徵就是語義特徵了,然後再在特徵層上進行兩幅圖像的MSE計算。在衆多模型中,SRGAN選用了VGG19模型,其截取的模型命名爲truncated_vgg19。所謂模型截斷,也就是隻提取原始模型的一部分,然後作爲一個新的單獨的模型進行使用。

至此重新整理下內容損失計算方式:

  • 通過SRResNet模型重建出高清圖像SR;
  • 通過truncated_vgg19模型對原始高清圖像H和重建出的高清圖像SR分別進行計算,得到兩幅圖像對應的特徵圖H_fea和SR_fea;
  • 計算H_fea和SR_fea的MSE值;

從上述計算方式上看出,原來的計算方式是直接計算H和SR的MSE值,而改用新的內容損失後只需要利用truncated_vgg19模型對圖像多作一次推理得到特徵圖,再在特徵圖上進行計算。

3. SRGAN結構剖析

SRGAN分爲兩部分:生成器模型(Generator)和判別器模型(Discriminator)。

生成器模型採用了SRResNet完全一樣的結構,只是在計算損失函數時需要利用截斷的VGG19模型進行計算。這裏注意,截斷的VGG19模型只是用來計算圖像特徵,其本身並不作爲一個子模塊加在生成器後面。可以將此處的VGG19模型理解爲靜止的(梯度不更新的),只是用它來計算一下特徵而已,其使用與一般的圖像濾波器sobel、canny算子等類似。

判別器模型結構如下所示:

                                                                                        圖16 SRGAN判別器模型

判別器模型對原始高清圖像或者重建的高清圖像進行判斷,判斷圖像到底是不是生成器創造出來。本質上是一個分類模型,因此判別器的最終輸出是一個1維的張量。判別器模型中間部分使用了多個卷積模塊進行特徵提取,這部分內容並沒有特別之處,因此本文不再對該模型進行闡述,讀者自行對照代碼理解即可。

4. Pytorch實現

Pytorch實現沿用前面SRResNet的設計框架。由於SRGAN算法的生成器部分採用的是與SRResNet模型完全一樣的結構,因此我們在訓練時就可以直接使用前面訓練好的SRResNet模型對生成器進行初始化以加快整個算法的收斂。

4.1 訓練

(1)訓練腳本train_srgan.py

參照SRResNet,完整的訓練代碼如下:

import torch.backends.cudnn as cudnn
import torch
from torch import nn
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from models import Generator, Discriminator, TruncatedVGG19
from datasets import SRDataset
from utils import *


# 數據集參數
data_folder = './data/'    # 數據存放路徑
crop_size = 96             # 高分辨率圖像裁剪尺寸
scaling_factor = 4         # 放大比例

# 生成器模型參數(與SRResNet相同)
large_kernel_size_g = 9   # 第一層卷積和最後一層卷積的核大小
small_kernel_size_g = 3   # 中間層卷積的核大小
n_channels_g = 64         # 中間層通道數
n_blocks_g = 16           # 殘差模塊數量
srresnet_checkpoint = "./results/checkpoint_srresnet.pth"  # 預訓練的SRResNet模型,用來初始化

# 判別器模型參數
kernel_size_d = 3  # 所有卷積模塊的核大小
n_channels_d = 64  # 第1層卷積模塊的通道數, 後續每隔1個模塊通道數翻倍
n_blocks_d = 8     # 卷積模塊數量
fc_size_d = 1024  # 全連接層連接數

# 學習參數
batch_size = 128    # 批大小
start_epoch = 1     # 迭代起始位置
epochs = 50         # 迭代輪數
checkpoint = None   # SRGAN預訓練模型, 如果沒有則填None
workers = 4         # 加載數據線程數量
vgg19_i = 5         # VGG19網絡第i個池化層
vgg19_j = 4         # VGG19網絡第j個卷積層
beta = 1e-3         # 判別損失乘子
lr = 1e-4           # 學習率

# 設備參數
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ngpu = 2                 # 用來運行的gpu數量
cudnn.benchmark = True   # 對卷積進行加速
writer = SummaryWriter() # 實時監控     使用命令 tensorboard --logdir runs  進行查看


def main():
    """
    訓練.
    """
    global checkpoint,start_epoch,writer

    # 模型初始化
    generator = Generator(large_kernel_size=large_kernel_size_g,
                              small_kernel_size=small_kernel_size_g,
                              n_channels=n_channels_g,
                              n_blocks=n_blocks_g,
                              scaling_factor=scaling_factor)

    discriminator = Discriminator(kernel_size=kernel_size_d,
                                    n_channels=n_channels_d,
                                    n_blocks=n_blocks_d,
                                    fc_size=fc_size_d)

    # 初始化優化器
    optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad,generator.parameters()),lr=lr)
    optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad,discriminator.parameters()),lr=lr)

    # 截斷的VGG19網絡用於計算損失函數
    truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
    truncated_vgg19.eval()

    # 損失函數
    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # 將數據移至默認設備
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    truncated_vgg19 = truncated_vgg19.to(device)
    content_loss_criterion = content_loss_criterion.to(device)
    adversarial_loss_criterion = adversarial_loss_criterion.to(device)
    
    # 加載預訓練模型
    srresnetcheckpoint = torch.load(srresnet_checkpoint)
    generator.net.load_state_dict(srresnetcheckpoint['model'])

    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        generator.load_state_dict(checkpoint['generator'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer_g.load_state_dict(checkpoint['optimizer_g'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
    
    # 單機多GPU訓練
    if torch.cuda.is_available() and ngpu > 1:
        generator = nn.DataParallel(generator, device_ids=list(range(ngpu)))
        discriminator = nn.DataParallel(discriminator, device_ids=list(range(ngpu)))

    # 定製化的dataloaders
    train_dataset = SRDataset(data_folder,split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='imagenet-norm')
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True) 

    # 開始逐輪訓練
    for epoch in range(start_epoch, epochs+1):
        
        if epoch == int(epochs / 2):  # 執行到一半時降低學習率
            adjust_learning_rate(optimizer_g, 0.1)
            adjust_learning_rate(optimizer_d, 0.1)

        generator.train()   # 開啓訓練模式:允許使用批樣本歸一化
        discriminator.train()

        losses_c = AverageMeter()  # 內容損失
        losses_a = AverageMeter()  # 生成損失
        losses_d = AverageMeter()  # 判別損失

        n_iter = len(train_loader)

        # 按批處理
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):

            # 數據移至默認設備進行訓練
            lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24),  imagenet-normed 格式
            hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96),  imagenet-normed 格式

            #-----------------------1. 生成器更新----------------------------
            # 生成
            sr_imgs = generator(lr_imgs)  # (N, 3, 96, 96), 範圍在 [-1, 1]
            sr_imgs = convert_image(
                sr_imgs, source='[-1, 1]',
                target='imagenet-norm')  # (N, 3, 96, 96), imagenet-normed

            # 計算 VGG 特徵圖
            sr_imgs_in_vgg_space = truncated_vgg19(sr_imgs)              # batchsize X 512 X 6 X 6
            hr_imgs_in_vgg_space = truncated_vgg19(hr_imgs).detach()     # batchsize X 512 X 6 X 6

            # 計算內容損失
            content_loss = content_loss_criterion(sr_imgs_in_vgg_space,hr_imgs_in_vgg_space)

            # 計算生成損失
            sr_discriminated = discriminator(sr_imgs)  # (batch X 1)   
            adversarial_loss = adversarial_loss_criterion(
                sr_discriminated, torch.ones_like(sr_discriminated)) # 生成器希望生成的圖像能夠完全迷惑判別器,因此它的預期所有圖片真值爲1

            # 計算總的感知損失
            perceptual_loss = content_loss + beta * adversarial_loss

            # 後向傳播.
            optimizer_g.zero_grad()
            perceptual_loss.backward()

            # 更新生成器參數
            optimizer_g.step()

            #記錄損失值
            losses_c.update(content_loss.item(), lr_imgs.size(0))
            losses_a.update(adversarial_loss.item(), lr_imgs.size(0))

            #-----------------------2. 判別器更新----------------------------
            # 判別器判斷
            hr_discriminated = discriminator(hr_imgs)
            sr_discriminated = discriminator(sr_imgs.detach())

            # 二值交叉熵損失
            adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.zeros_like(sr_discriminated)) + \
                            adversarial_loss_criterion(hr_discriminated, torch.ones_like(hr_discriminated))  # 判別器希望能夠準確的判斷真假,因此凡是生成器生成的都設置爲0,原始圖像均設置爲1

            # 後向傳播
            optimizer_d.zero_grad()
            adversarial_loss.backward()

            # 更新判別器
            optimizer_d.step()

            # 記錄損失
            losses_d.update(adversarial_loss.item(), hr_imgs.size(0))

            # 監控圖像變化
            if i==(n_iter-2):
                writer.add_image('SRGAN/epoch_'+str(epoch)+'_1', make_grid(lr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRGAN/epoch_'+str(epoch)+'_2', make_grid(sr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRGAN/epoch_'+str(epoch)+'_3', make_grid(hr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)

            # 打印結果
            print("第 "+str(i)+ " 個batch結束")
 
        # 手動釋放內存              
        del lr_imgs, hr_imgs, sr_imgs, hr_imgs_in_vgg_space, sr_imgs_in_vgg_space, hr_discriminated, sr_discriminated  # 手工清除掉緩存

        # 監控損失值變化
        writer.add_scalar('SRGAN/Loss_c', losses_c.val, epoch) 
        writer.add_scalar('SRGAN/Loss_a', losses_a.val, epoch)    
        writer.add_scalar('SRGAN/Loss_d', losses_d.val, epoch)    

        # 保存預訓練模型
        torch.save({
            'epoch': epoch,
            'generator': generator.module.state_dict(),
            'discriminator': discriminator.module.state_dict(),
            'optimizer_g': optimizer_g.state_dict(),
            'optimizer_g': optimizer_g.state_dict(),
        }, 'results/checkpoint_srgan.pth')
    
    # 訓練結束關閉監控
    writer.close()


if __name__ == '__main__':
    main()

相關定義模型在models.py中給出:

class Generator(nn.Module):
    """
    生成器模型,其結構與SRResNet完全一致.
    """

    def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
        """
        參數 large_kernel_size:第一層和最後一層卷積核大小
        參數 small_kernel_size:中間層卷積核大小
        參數 n_channels:中間層卷積通道數
        參數 n_blocks: 殘差模塊數量
        參數 scaling_factor: 放大比例
        """
        super(Generator, self).__init__()
        self.net = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                            n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)

    def forward(self, lr_imgs):
        """
        前向傳播.

        參數 lr_imgs: 低精度圖像 (N, 3, w, h)
        返回: 超分重建圖像 (N, 3, w * scaling factor, h * scaling factor)
        """
        sr_imgs = self.net(lr_imgs)  # (N, n_channels, w * scaling factor, h * scaling factor)

        return sr_imgs


class Discriminator(nn.Module):
    """
    SRGAN判別器
    """

    def __init__(self, kernel_size=3, n_channels=64, n_blocks=8, fc_size=1024):
        """
        參數 kernel_size: 所有卷積層的核大小
        參數 n_channels: 初始卷積層輸出通道數, 後面每隔一個卷積層通道數翻倍
        參數 n_blocks: 卷積塊數量
        參數 fc_size: 全連接層連接數
        """
        super(Discriminator, self).__init__()

        in_channels = 3

        # 卷積系列,參照論文SRGAN進行設計
        conv_blocks = list()
        for i in range(n_blocks):
            out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
            conv_blocks.append(
                ConvolutionalBlock(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0, activation='LeakyReLu'))
            in_channels = out_channels
        self.conv_blocks = nn.Sequential(*conv_blocks)

        # 固定輸出大小
        self.adaptive_pool = nn.AdaptiveAvgPool2d((6, 6))

        self.fc1 = nn.Linear(out_channels * 6 * 6, fc_size)

        self.leaky_relu = nn.LeakyReLU(0.2)

        self.fc2 = nn.Linear(1024, 1)

        # 最後不需要添加sigmoid層,因爲PyTorch的nn.BCEWithLogitsLoss()已經包含了這個步驟

    def forward(self, imgs):
        """
        前向傳播.

        參數 imgs: 用於作判別的原始高清圖或超分重建圖,張量表示,大小爲(N, 3, w * scaling factor, h * scaling factor)
        返回: 一個評分值, 用於判斷一副圖像是否是高清圖, 張量表示,大小爲 (N)
        """
        batch_size = imgs.size(0)
        output = self.conv_blocks(imgs)
        output = self.adaptive_pool(output)
        output = self.fc1(output.view(batch_size, -1))
        output = self.leaky_relu(output)
        logit = self.fc2(output)

        return logit


class TruncatedVGG19(nn.Module):
    """
    truncated VGG19網絡,用於計算VGG特徵空間的MSE損失
    """

    def __init__(self, i, j):
        """
        :參數 i: 第 i 個池化層
        :參數 j: 第 j 個卷積層
        """
        super(TruncatedVGG19, self).__init__()

        # 加載預訓練的VGG模型
        vgg19 = torchvision.models.vgg19(pretrained=True)  # C:\Users\Administrator/.cache\torch\checkpoints\vgg19-dcbb9e9d.pth

        maxpool_counter = 0
        conv_counter = 0
        truncate_at = 0
        # 迭代搜索
        for layer in vgg19.features.children():
            truncate_at += 1

            # 統計
            if isinstance(layer, nn.Conv2d):
                conv_counter += 1
            if isinstance(layer, nn.MaxPool2d):
                maxpool_counter += 1
                conv_counter = 0

            # 截斷位置在第(i-1)個池化層之後(第 i 個池化層之前)的第 j 個卷積層
            if maxpool_counter == i - 1 and conv_counter == j:
                break

        # 檢查是否滿足條件
        assert maxpool_counter == i - 1 and conv_counter == j, "當前 i=%d 、 j=%d 不滿足 VGG19 模型結構" % (
            i, j)

        # 截取網絡
        self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])

    def forward(self, input):
        """
        前向傳播
        參數 input: 高清原始圖或超分重建圖,張量表示,大小爲 (N, 3, w * scaling factor, h * scaling factor)
        返回: VGG19特徵圖,張量表示,大小爲 (N, feature_map_channels, feature_map_w, feature_map_h)
        """
        output = self.truncated_vgg19(input)  # (N, feature_map_channels, feature_map_w, feature_map_h)

        return output

在代碼中已經給出了詳細的註釋,讀者在運行調試時結合註釋相信可以快速的理解整個處理流程。

(2)訓練結果

下圖分別展示了整個訓練過程中內容損失、生成損失和判別損失的變化曲線。

 

                                                                                                      圖18 損失函數變化曲線

從上圖中可以看到,相對SRResNet的收斂曲線,SRGAN非常不平穩,判別損失和生成損失此消彼長,這說明判別器和生成器正在做着激烈的對抗。一般來說,生成對抗網絡相比普通的網絡其訓練難度更大,我們無法通過查看loss來說明gan訓練得怎麼樣。目前也有不少文獻開始嘗試解決整個問題,使得GAN算法的訓練進程可以更加明顯。

儘管不能從loss損失函數變化曲線上看出訓練進程,我們還可以從每次epoch的訓練樣本重建效果上進行查看。下圖分別顯示了epoch=1、25和50 部分訓練樣本重建效果圖,第一行爲低分辨率圖,第二行爲超分重建圖,第三行爲原始高清圖。

                                                                                          圖19 epoch=1時訓練結果圖

                                                                                           圖20 epoch=25時訓練結果圖

                                                                                   圖21 epoch=50時訓練結果圖

從訓練圖上可以看到,在epoch=50即訓練結束的時候,其生成到的超分圖已經非常接近原始高清圖,重建出的圖像視覺感受更加突出,細節較豐富,相比SRResNet的過度平滑,其生成的圖像更符合真實場景效果。

4.2 評估

完整的評估代碼如下eval.py:

from utils import *
from torch import nn
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from datasets import SRDataset
from models import SRResNet,Generator
import time

# 模型參數
large_kernel_size = 9   # 第一層卷積和最後一層卷積的核大小
small_kernel_size = 3   # 中間層卷積的核大小
n_channels = 64         # 中間層通道數
n_blocks = 16           # 殘差模塊數量
scaling_factor = 4      # 放大比例
ngpu = 2                # GP數量
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':
    
    # 測試集目錄
    data_folder = "./data/"
    test_data_names = ["Set5","Set14", "BSD100"]

    # 預訓練模型
    srgan_checkpoint = "./results/checkpoint_srgan.pth"
    #srresnet_checkpoint = "./results/checkpoint_srresnet.pth"

    # 加載模型SRResNet 或 SRGAN
    checkpoint = torch.load(srgan_checkpoint)
    generator = Generator(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    generator = generator.to(device)
    generator.load_state_dict(checkpoint['generator'])

    # 多GPU測試
    if torch.cuda.is_available() and ngpu > 1:
        generator = nn.DataParallel(generator, device_ids=list(range(ngpu)))
   
    generator.eval()
    model = generator
    # srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device)
    # srgan_generator.eval()
    # model = srgan_generator

    for test_data_name in test_data_names:
        print("\n數據集 %s:\n" % test_data_name)

        # 定製化數據加載器
        test_dataset = SRDataset(data_folder,
                                split='test',
                                crop_size=0,
                                scaling_factor=4,
                                lr_img_type='imagenet-norm',
                                hr_img_type='[-1, 1]',
                                test_data_name=test_data_name)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1,
                                                pin_memory=True)

        # 記錄每個樣本 PSNR 和 SSIM值
        PSNRs = AverageMeter()
        SSIMs = AverageMeter()

        # 記錄測試時間
        start = time.time()

        with torch.no_grad():
            # 逐批樣本進行推理計算
            for i, (lr_imgs, hr_imgs) in enumerate(test_loader):
                
                # 數據移至默認設備
                lr_imgs = lr_imgs.to(device)  # (batch_size (1), 3, w / 4, h / 4), imagenet-normed
                hr_imgs = hr_imgs.to(device)  # (batch_size (1), 3, w, h), in [-1, 1]

                # 前向傳播.
                sr_imgs = model(lr_imgs)  # (1, 3, w, h), in [-1, 1]                

                # 計算 PSNR 和 SSIM
                sr_imgs_y = convert_image(sr_imgs, source='[-1, 1]', target='y-channel').squeeze(
                    0)  # (w, h), in y-channel
                hr_imgs_y = convert_image(hr_imgs, source='[-1, 1]', target='y-channel').squeeze(0)  # (w, h), in y-channel
                psnr = peak_signal_noise_ratio(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(),
                                            data_range=255.)
                ssim = structural_similarity(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(),
                                            data_range=255.)
                PSNRs.update(psnr, lr_imgs.size(0))
                SSIMs.update(ssim, lr_imgs.size(0))


        # 輸出平均PSNR和SSIM
        print('PSNR  {psnrs.avg:.3f}'.format(psnrs=PSNRs))
        print('SSIM  {ssims.avg:.3f}'.format(ssims=SSIMs))
        print('平均單張樣本用時  {:.3f} 秒'.format((time.time()-start)/len(test_dataset)))

    print("\n")

最終在三個數據集上的測試結果如下表所示:

  Set5 Set14 BSD100
PSNR 29.021 25.652 24.833
SSIM 0.839 0.693 0.650
單張圖片平均用時(毫秒) 850 650 80

上表中結果與論文中的值較爲接近。可以看到,其PSNR和SSIM效果並不好,這是因爲SRGAN本質上就不是爲了PSNR和SSIM指標而設計優化的。SRGAN論文中也指出使用PSNR和SSIM會讓算法過度平滑,在視覺效果上並不理想。爲了定量評價SRGAN算法效果,論文中新設計了MOS(Mean Option Score)指標,簡單來說,就是讓多位觀察者採用主觀評價的方式對重建效果圖進行打分,最後將分數作平均並以此作爲評價指標。

4.3 測試

完整的測試代碼如下test.py:

from utils import *
from torch import nn
from models import SRResNet,Generator
import time
from PIL import Image

# 測試圖像
imgPath = './results/test.jpg'

# 模型參數
large_kernel_size = 9   # 第一層卷積和最後一層卷積的核大小
small_kernel_size = 3   # 中間層卷積的核大小
n_channels = 64         # 中間層通道數
n_blocks = 16           # 殘差模塊數量
scaling_factor = 4      # 放大比例
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':

    # 預訓練模型
    srgan_checkpoint = "./results/checkpoint_srgan.pth"
    #srresnet_checkpoint = "./results/checkpoint_srresnet.pth"

    # 加載模型SRResNet 或 SRGAN
    checkpoint = torch.load(srgan_checkpoint)
    generator = Generator(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    generator = generator.to(device)
    generator.load_state_dict(checkpoint['generator'])
   
    generator.eval()
    model = generator

    # 加載圖像
    img = Image.open(imgPath, mode='r')
    img = img.convert('RGB')

    # 雙線性上採樣
    Bicubic_img = img.resize((int(img.width * scaling_factor),int(img.height * scaling_factor)),Image.BICUBIC)
    Bicubic_img.save('./results/test_bicubic.jpg')

    # 圖像預處理
    lr_img = convert_image(img, source='pil', target='imagenet-norm')
    lr_img.unsqueeze_(0)

    # 記錄時間
    start = time.time()

    # 轉移數據至設備
    lr_img = lr_img.to(device)  # (1, 3, w, h ), imagenet-normed

    # 模型推理
    with torch.no_grad():
        sr_img = model(lr_img).squeeze(0).cpu().detach()  # (1, 3, w*scale, h*scale), in [-1, 1]   
        sr_img = convert_image(sr_img, source='[-1, 1]', target='pil')
        sr_img.save('./results/test_srgan.jpg')

    print('用時  {:.3f} 秒'.format(time.time()-start))

測試效果如下:

 

 

圖22 低分辨率證件照測試效果圖,從左到右從上到下依次爲:原圖、bicubic上採樣、SRResNet超分重建、SRGAN超分重建

四. 總結

本文詳細講述了超分重建的概念和研究進展。針對有代表性的SRResNet和SRGAN算法,分別進行了原理剖析,並給出了Pytorch實現代碼。從最終的效果上來看,達到了原論文裏的效果。如果讀者希望能夠較好的入門超分重建領域,那麼可以從本文出發,在掌握原理的基礎上按照本文示例自己動手完成算法建模,可以爲自己在超分重建領域打下良好的基礎。當然,如果讀者並不是研究超分重建領域,那麼本文也可以作爲一個實戰案例,學習生成對抗網絡的操作技巧。本文在組織代碼時力求簡單明瞭,並不希望成爲一個負擔較重的“工程”,在簡潔的基礎上儘量“傻瓜式”,這樣也方便讀者可以任意的對其進行擴展和操作。

由於水平有限,本文肯定有不少理解上的錯誤或者是代碼實現上的問題,還請讀者能夠多多指正,共同進步!

下一篇博文打算研究圖像語義分割、摳圖領域,同樣會以簡潔實用爲目標,有興趣的讀者後面可以繼續關注!

 參考文獻

【1】Dong C, Loy C C, He K, et al. Image Super-Resolution Using Deep Convolutional Networks[C]. Computer Vision and Pattern Recognition, 2014.

【2】Shi W, Caballero J, Huszar F, et al. Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network[C]. Computer Vision and Pattern Recognition, 2016: 1874-1883.

【3】Ledig C, Theis L, Huszar F, et al. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network[C]. Computer Vision and Pattern Recognition, 2017.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章