FasterViT:英偉達提出分層注意力,構造高吞吐CNN-ViT混合網絡 | ICLR 2024

論文設計了新的CNN-ViT混合神經網絡FasterViT,重點關注計算機視覺應用的圖像吞吐能力。FasterViT結合CNN的局部特徵學習的特性和ViT的全局建模特性,引入分層注意力(HAT)方法在降低計算成本的同時增加窗口間的交互。在包括分類、對象檢測和分割各種CV任務上,FasterViT在精度與圖像吞吐量方面實現了SOTAHAT可用作即插即用的增強模塊

來源:曉飛的算法工程筆記 公衆號

論文: FasterViT: Fast Vision Transformers with Hierarchical Attention

Introduction


ViT最近在計算機視覺領域變得流行,並在圖像分類、目標檢測和語義分割等各種應用中取得了卓越的性能。儘管如此,純ViT模型由於缺乏歸納偏置,導致需要更多的訓練數據並可能影響性能。而由CNNViT組成的混合架構則可以解決這個問題並達到有競爭力的性能,無需大規模訓練數據集或知識蒸餾等其他技術。

ViT的一個組成部分是自注意力機制,可以對短距離和長距離的空間關係進行建模。但由於自注意力的二次計算複雜度會顯着影響效率,阻礙其在高分辨率圖像應用中的使用。此外,與原始ViT模型的架構(即固定分辨率,無下采樣)相反,以多尺度的方式學習特徵通常會產生更好的性能,特別是對於下游應用(如檢測、分割)。

  爲了解決這些問題,Swin Transformer提出了一種多尺度架構,其中自注意力在局部窗口內計算,通過窗口移動保證不同區域之間的交互。但由於局部區域的感受域有限且窗口移動的覆蓋範圍較小,跨窗口交互和長距離空間關係的建模在高分辨率輸入的任務中依然具有挑戰性。此外,在早期分辨率較大階段也可能會由於局部窗口數量的增加而影響圖像吞吐量。最近,Swin Transformer V2通過改進自注意力機制來解決高分辨率圖像訓練的不穩定問題。但與Swin Transformer相比,除了較低的圖像吞吐量之外,Swin Transformer V2仍然依賴原始的窗口移動機制來進行不同窗口的交互,這在處理大圖像時依然不高效。

  於是,論文提出一種專爲高分辨率圖像輸入量身定製的FasterViT混合架構,能保持較大的圖像吞吐。FasterViT由四個不同的階段組成,在高分辨率階段(即階段 1、2)使用殘差卷積塊,在後續階段(即階段 3、4)使用Transformer塊,階段之間通過步長卷積層來降低輸入圖像分辨率以及加倍通道數。這樣的架構可以快速生成高質量token,然後基於Transformer塊來進一步處理這些token。對於每個Transformer塊,論文使用分層注意力塊來提取長短距離空間關係,進行有效的跨窗口交互。

  如圖2所示,分層注意力機制爲每個局部窗口學習一個carrier token作爲總結,然後基於carrier token對窗口之間的交互模式進行建模。由於有基於局部窗口的注意力作爲計算約束,隨着區域數量的增加,分層注意力的計算複雜度幾乎隨輸入圖像分辨率線性增長。因此,它是捕獲高分辨率特徵的遠距離關係的高效且有效的方法。

  論文在各種圖像任務和數據集上廣泛地驗證了所提出的FasterViT模型的有效性,考慮性能和圖像吞吐量之間的權衡,FasterViTImageNet-1K top-1實現了最先進的性能。爲了展示FasterViT對於更大數據集的可擴展性,論文還在ImageNet-21K數據集上對FasterViT進行了預訓練,並在更大規模和更大分辨率的任務上進行微調和評估,實現了最先進的性能。

  論文的貢獻總結如下:

  • 推出新穎的FasterViT混合視覺架構,旨在實現性能和圖像吞吐之間的最佳平衡,可以針對不同的數據集和模型大小有效地縮放到更高分辨率的輸入圖像。
  • 提出了分層注意力模塊,可以有效地捕獲局部區域的跨窗口交互,並對長距離空間關係進行建模。
  • FasterViT在圖像吞吐和準確性之間的權衡上實現了新的SOTA,比基於ViT的同類架構和最新的SOTA模型要快得多。同時,在MS COCO數據集上的檢測和實例分割以及ADE20K數據集上的語義分割達到了具有競爭力的性能。

FasterViT


Design Principals

  論文專注於在主流硬件上實現計算機視覺任務的最高吞吐量,需要在數據傳輸和計算之間進行仔細的平衡,以最大限度地提高吞吐量。

  在分層視覺模型中,中間特徵的空間維度隨着推理的進行而縮小。初始網絡層具有較大的空間維度和較少的通道(例如 \(112\times 112 \times 64\)),導致可選操作受內存傳輸限制,應該多使用如密集卷積的計算密集型操作。此外,無法以矩陣形式表示的操作(例如非線性、池化、批量歸一化)也是受內存限制的,應儘量減少使用。相反,後面的層則往往受到運算量限制。比如分層CNN具有大小爲 \(14\times 14\) 的高維特徵,爲使用提取能力更強的操作(例如層歸一化、SE或注意力)留下了空間,而且對吞吐量的影響相當小。

Architecture

  整體設計如圖 3 所示,在早期階段使用卷積層處理高分辨率特徵,後半部分依賴於新穎的分層注意力層來對整個特徵圖進行空間推理。在此設計中,論文根據計算量和吞吐量優化了架構。前半部分和下采樣塊使用了密集卷積,避免使用SE算子。同時,需要最小化高分辨率階段(即 1、2)的層歸一化使用,因爲這些層往往受到內存傳輸限制。而後期階段(即 3、4)通常會受到計算量限制,與內存傳輸成本相比,GPU硬件在計算上花費更多時間,應用多頭注意力也不會成爲瓶頸。

FasterViT Components

  • Stem

  輸入圖像 \({\textbf{x}}\in{\mathrm{R}^{H\times W\times3}}\) 通過連續的兩個 \(3\times3\) 卷積層投影爲 \(D\)embedding,每個卷積層的步長爲2。embedding會進一步批歸一化,每次卷積後都會使用ReLU激活函數。

  • Downsampler Blocks

FasterViT的下采樣塊先對空間特徵應用2D層歸一化,然後使用內核爲 \(3\times3\) 且步長爲 2 的卷積層,將空間分辨率降低2倍。

  • Conv Blocks

  階段 1 和階段 2 由殘差卷積塊組成,定義爲

\[\begin{array}{l}{{\hat{\mathbf{x}}=\mathbf{G}\mathrm{E}\mathrm{L}\mathrm{U}(\mathrm{BN}(\mathrm{Conv_{3\times 3}}(\mathrm{x}))),}}\\ {{\mathbf{x}=\mathrm{BN}(\mathrm{Conv}_{3\times 3}(\hat{\mathbf{x}}))+\mathbf{x}}}\end{array} \quad\quad (1) \]

  其中BN表示批歸一化。遵循設計原則,這些卷積是密集的。

  • Hierarchical Attention

  在這項工作中,論文提出了一種新穎的窗口注意力模塊,整體如圖 2 所示,詳細介紹如圖 4 所示。核心是在Swin Transformer的局部窗口上引入carrier tokensCT)用於彙總局部窗口的信息,隨後基於CT進行局部窗口之間的信息交互。

  假設論文給出一個輸入特徵圖 \(\mathbf{x}\in\mathbb{R}^{{H}\times W\times d}\),其中 \(\textstyle H\)\(\dot{W}\)\(d\) 表示特徵圖的高度、寬度和維度。爲了簡單起見,設置\(H=W\)。以 \(n = \frac{H^{2}}{k^{2}}\) 將輸入特徵圖劃分爲 \(n\times n\) 個局部窗口,其中 \(k\) 是窗口大小,如下所示:

\[{\hat{\mathbf{x}}}_\mathbf{l}=\mathbf{Split}_{k\times k}(\mathbf{x}) \quad\quad (2) \]

  通過池化每個窗口得到 \(L=2^{c}\)token來初始化CT

\[\begin{array}{l}{{\hat{\bf x}_{\mathrm{c}}=\mathbf{Conv}_{3\times 3}({\bf x}),}}\\ {{\hat{\bf x}_{\mathrm{ct}}=\mathrm{AvgPool}_{H^{2}\to n^{2}L}(\hat{\bf x}_{\mathrm{c}}),}}\end{array} \quad\quad (3) \]

  其中 \(\mathbf{Conv}_{3\times 3}\)Twins中使用的高效位置編碼,\(\hat{\bf x}_{\mathrm{ct}}\)AvgPool分別表示carrier token和特徵池化操作。這些池化的token代表了各自局部窗口的總結,一般都有 \(L << k\),論文將c設置爲1CT的初始化在每個階段僅執行一次,每個局部窗口 \({\hat{\mathbf{x}}}_{l}\) 都有唯一的CT\(\hat{\bf x}_{\mathrm{ct},1}\),構成\(\hat{\bf x}_{\mathrm{ct}}\:=\:\{\hat{\bf x}_{\mathrm{ct},1}\}_{1=0}^{n}.\)集合。

  在每個HAT塊中,CT都會經歷以下注意力處理:

\[\begin{array}{l}{{\hat{\mathbf{x}}_{\mathrm{ct}}=\hat{\mathbf{x}}_{\mathrm{ct}}+\gamma_{1}\cdot{\mathbf{M H S A}}(\mathbf{LN}(\hat{\mathbf{x}}_{\mathrm{ct}})),}}\\ {{\hat{\mathbf{x}}_{\mathrm{ct}}=\hat{\mathbf{x}}_{\mathrm{ct}}+\gamma_{2}\cdot{\mathbf{M L P}}_{d\to4d\to d}(\mathbf{LN}(\hat{x}_{\mathrm{ct}})),}}\end{array} \quad\quad (4) \]

  其中LN表示層歸一化,MHSA表示多頭自注意力,\(\gamma\) 是可學習的每個通道特定的縮放因子,\(\mathbf{MLP}_{d\rightarrow4d\rightarrow d}\) 是帶有GeLU激活函數的兩層MLP結構。

  接下來,爲了對長短距離空間信息進行建模,論文需要進行局部token\(\hat{\mathbf{x}}_{l}\)carrier token\({\hat{\mathbf{x}}}_{\mathrm{ct,l}}\) 之間的交互信息。

  首先,將局部特徵和CT連接起來,每個局部窗口只能訪問其相應的CT

\[\hat{\bf x}_{\mathrm{w}}=\mathbf{Concat}(\hat{\bf x}_{l},\hat{\bf x}_{\mathrm{ct,l}}) \quad\quad (5) \]

  隨後進行另一組注意力處理:

\[\begin{array}{l}{{\hat{\mathbf{x}}_{\mathrm{w}}=\hat{\mathbf{x}}_{\mathrm{w}}+\gamma_{1}\cdot{\mathbf{M H S A}}(\mathbf{LN}(\hat{\mathbf{x}}_{\mathrm{w}})),}}\\ {{\hat{\mathbf{x}}_{\mathrm{w}}=\hat{\mathbf{x}}_{\mathrm{w}}+\gamma_{2}\cdot{\mathbf{M L P}}_{d\to4d\to d}(\mathbf{LN}(\hat{x}_{\mathrm{w}})),}}\end{array} \quad\quad (6) \]

  最後,token被進一步拆分回局部特徵和CT,用於後續的分層注意力層:

\[\hat{\mathbf{x}}_{l},\hat{\mathbf{x}}_{\mathrm{ct.1}}=\mathbf{Spilt}(\hat{\mathbf{x}}_{\mathrm{w}}) \quad\quad (7) \]

  公式 4-7 在階段中的迭代執行,爲了進一步促進長距離交互,論文在階段末尾設計了全局信息傳播計算如下:

\[{\bf x}=\mathbf{Upsample}_{n^{2}L\to H^{2}}(\hat{\bf x}_{\mathrm{ct},l})+\mathbf{Merge}_{n^{2}k^{2}\to H^{2}}(\hat{\bf x}_{l}) \quad\quad (8) \]

  在公式 4 和 6 中,MHSA具有token位置不變性,但顯然特徵在空間維度中的位置能提供更豐富的信息。爲了解決這個問題,論文效仿SwinV2採用兩層MLP2D絕對位置信息嵌入到CT和局部窗口token中。爲了促進類似圖像的局部歸納偏差,論文還使用SwinV2的對數空間的相對位置偏差來增強注意力計算,確保token的相對位置有助於注意力學習。因爲位置編碼由MLP插值,這種方法對圖像大小變化是具有靈活性的,經過訓練的模型可以應用於任何輸入分辨率。

  多種全局-局部自注意力之間的比較如圖 5 所示,分層注意力將全局注意力分爲局部注意力和次全局注意力,兩者都可壓縮爲 2 個密集註意力。CT參與雙方的關注並促進信息交換。

  • Complexity Analysis of HAT

  最傳統和流行的完全注意力的複雜性是 \(O(H^{4}d)\),將特徵大小劃分爲大小爲 \(k\) 的窗口並運行注意力,能簡化到 \(O(k^2H^{2}d)\)

  衆所周知,窗口注意力更有效但缺乏全局特徵交互。論文基於CT在整個特徵圖上進行總結和交互,以彌補全局交互的缺失。給定每個窗口的 \(L\)CT,局部窗口計算的複雜度爲 \(O((k^{2}+L)H^{2}d)\)CT注意力計算的複雜度爲 \(O((\frac{H^{2}}{k^{2}}L)^{2}d)\),兩種注意力的總成本爲 \(O(k^{2}H^{2}d+LH^{2}d+\frac{H^{4}}{k^{4}}L^{2}d).\)

  多級注意力的另一種方式爲局部注意力提供子採樣的全局信息,如Twins對全局特徵圖進行二次採樣並將其用作局部窗口注意力的鍵和值,複雜度爲\(O(k^{2}H^{2}d+\frac{H^{4}}{k^{2}}d)\)。在相同大小的局部窗口(\(k\))和 \(H\) 下,HAT的複雜度爲 \(O(L\ +\ {\frac{H^{2}L^{2}}{k^{4}}})\)Twins的複雜度爲 \(O\bigl({\frac{H^{2}}{k^{2}}}\bigr)\)。分辨率越高,HAT的效率越高。對於 \(H=32\)\(k=8\) ,當 \(L=4\) 時,HAT\(O(8)\),而Twins\(O(16)\)

Experiments


Image Classification

  表 1 中展示了FasterViT模型其它模型在ImageNet-1K數據集表現。

  爲了驗證所提出模型的可擴展性,論文在ImageNet-21K數據集上預訓練FasterViT-4,並在ImageNet-1K數據集上對各種圖像分辨率進行微調。一般來說,與其他同類模型相比,FasterViT-4具有更好的精度-吞吐量權衡。

Object Detection and Instance Segmentation

  表 3 展示了使用Cascade Mask R-CNN網絡在MS COCO數據集上的對象檢測和實例分割基準。與其他模型相比,FasterViT模型作爲主幹會具有更好的精度-吞吐量權衡。

Semantic Segmentation

  表 5 展示了使用UPerNet網絡在ADE20K數據集上的語義分割基準。與之前的任務類似,FasterViT模型同樣有更好的性能與吞吐量權衡。

Component-wise study



如果本文對你有幫助,麻煩點個贊或在看唄~
更多內容請關注 微信公衆號【曉飛的算法工程筆記】

work-life balance.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章