TensorFlow計算圖優化代碼剖析

代碼路徑:tensorflow/core/grappler/optimizers
其中meta_optimizer.cc中的RunMetaOptimizer方法的調用觸發對圖的不同類型的優化操作.

優化操作分爲一下幾類:
1. pruning.裁剪,比如移除一些無用的操作(一旦圖建立之後不再使用的stop gradient節點以及Identity節點),優化梯度計算.
2. constfold.常量打包.
3. layout. 對tensor的layout針對計算庫以及設備進行調整.比如cudnn使用NCHW比較高效.
4. memory.
5. arithmetic.
6. autoparallel.
以上optimizer均可以同時使用.
下面我們對以上六種圖優化手段逐一進行代碼級剖析.

pruning

ModelPruner類有三個成員函數, name()方法返回名稱, Optimize方法負責具體的優化操作. Feedback方法.
目的: 將所有不會被執行的節點都裁剪掉. 也就是那些不會被fanin的節點.如果沒有指定fetch節點,將假設整個圖都將被執行. 
不能移除必須被保留的節點(在nodes_to_prserve中);
不能移除驅動control依賴的節點;
不能移除無法確定移除後是否會新增control依賴的節點(比如,移除一個10條control edge同時驅動10條control edge,將新建100條edge);
不能移除與function鏈接的節點,因爲會導致後面內聯失敗;
不能移除被其它設備驅動的節點,因爲使用這些節點能夠降低通信開銷;
不能移除接收引用值的節點,將引用轉換成非引用也不行(可能理解的不大對).

const folding

對圖中常量進行合併優化.遍歷圖中節點,找出完全能夠靜態計算的節點,也就是說完全依賴於constants輸入的.在CPU上將這些節點計算出來,並替換這些節點.沒有CPU KERNEL的op不能進行constant folding.
輔助類:
EigenThreadPoolWrapper: 對Eigen庫中threadpool進行封裝,提供Schedule選取一個線程執行特定的函數.
DeviceSimple: 繼承自DeviceBase

不能直接在switch節點上直接進行控制依賴的固定,因爲和其它節點不同的是,執行時switch節點之後產生一個輸出,並且我們必須確保控制依賴只有在對應的輸出被觸發時觸發.我們一開始是通過查找到一個是switch節點輸出節點關聯的identity節點,並將它作爲控制依賴的標定點.如果我們找不到這樣的節點,那麼就需要添加一個額外的identity節點.
輔助函數:
AddControlDependency函數:查找一個用於標定control dependency的節點,如果沒有,需要添加.
MaterializeShapes: 將shape或者size或者rank操作實質化,因爲計算圖中tensor的這些屬性是可以推導計算的.
IsFoldable: 如果一個節點的輸入是空的,是不支持fold的.跳過指定preserve(白名單的除外)的節點.跳過const類型的操作,因爲這些節點已經fold過了.跳過控制流節點.沒必要fold沒有出邊的節點,除了白名單的節點.這些節點會在前期的常量folding過程中處理,如果用戶想要取它們的值,那麼需要保留.不能重複進行處理(fold檢查並執行folding操作會出錯).如果一個節點的所有輸入都是靜態可知,除了一種特殊情況,比如一個合併節點,只有第一個輸入可用時,要求一個單獨的constant輸入可以被fold操作.(比較繞,具體還是建議大家看看代碼).暫時不支持對string constant,可以理解爲checkpoint時有bug.
EvaluateNode: 計算給定節點的輸出,給定節點的輸入,調用該opKernel->Compute函數.被EvaluateOneFoldable函數調用,根據輸出新建一個節點.

layout

將NHWC的內存佈局轉換到GPU相關的操作NCHW(主要和卷積相關,cudnn使用NCHW比較高效)
輔助類:
GraphProcessor:主要提供了三個往計算圖中添加const類型節點的方法(permutation/scalar/reduction,注意:三類詳細的區別還不是很清楚,均需要指定device)
NodeProcessor:繼承自GraphProcessor,三個成員方法updateAttrDataFormat(如果format爲NHWC,那麼設置爲NCHW)/UpdateAttrShape(將輸出的shape設置爲CHW)/updateAttrSize/updateAttrStrides/updateAttrValue/updateAttrValueOfInput.用於修改輸入輸出數據的shape,同時更新內部數據結構屬性值.AddNodeTranspose添加轉換節點到計算圖中.AddLayoutTransposeToInputs(調用AddNodeTranspose方法爲輸入添加layout transpose操作),AddLayoutTransposeToOutputs(爲輸出添加layout transpose).

AvgPoolGradProcessor: 繼承自NodeProcessor.
BiasAddGradProcessor:繼承自NodeProcessor.
Conv2DProcessor(stride爲1,如果大於一,那麼不進行layout轉換操作,是否爲有效padding)
Conv2DBackpropFilterProcessor:繼承自Conv2DProcesso
Conv2DBackpropInputProcessor
FusedBatchNormGradProcessor
MaxPoolGradProcessor
AgnosticNodeProcessor
AddNProcessor,BinaryOpProcessor,ConcatProcessor,ReluGradProcessor,SliceProcessor:繼承自AgnosticNodeProcessor
SliceProcessorConst:這個類主要是應對一種特殊情況,當第二三輸入均爲CONST時,首先會進行const folding操作,然後再進行slice優化.
SliceProcessorConcatOffset:當第二個輸入爲ConcatOffset.(比如inceptionV3中concat梯度計算)
SqueezeProcessor
SumProcessor
備註: conv2D,conv2DBackpropInput以及conv2DBackpropFilter,當filter size爲1,或者等於輸入image size時,NHWC的實現將採用特定的GEMM實現,通常來說會比NCHW的實現快.
DataLayoutOptimizer:繼承自GrapProcessor.執行時需要遍歷兩次所有的節點shape,第一次是擴展支持NCHW的節點. 第二次是擴展layout不可知的節點.(collapse函數是爲了合併所有的節點對,比如兩個節點均是transpose操作,而且相反,那麼可以合併.)

最後實現optimize方法,對圖完成LayoutOptimizer操作.(代碼實現中基於經驗觀察,如果引入的轉換節點個數超過30個,那麼不使用gemm的實現能夠獲得更好的性能)

Memory

主要目的: 將tensor從設備內存中換入換出.
構造函數指定優化級別(autonomy級別爲memory optimizer),提到rewriterConfig,指定recomputation_targets_name_prefix以及memory_optimizer_target_node_name_prefix.

備註:這裏的內存優化策略是將forward的部分op結果swap到host-memory,然後計算backward gradients時重新計算該op,達到節省顯存的目的.或者標定一些節點爲recomputed_node.

GetCheapToRecomputeOps方法返回一個op名稱的數組,標記爲這些操作爲輕量級可recompute的操作.目前的實現僅僅提供一些靜態的數組,後期可能會提供一個代價模型更加合理的op列表.
FindCandidateRecomputeNodes方法:找出所有feed給目標節點的recomputable ops.
connected_subgraph: 爲candidateRecommputable 節點生成連接圖.
GetOpGroupsToRecompute方法:基於should_recompute方法,找出幾組op一起recompute.返回一個需要recompute的一組子圖.
GetMaxDownstreamComponents:計算最大的拓撲數量(1,目標節點的組成,即梯度節點,feed by recomputation),(2,每個recomputed node的子重計算節點,) 當componet的數量大於這個值的時候,需要爲一個重計算添加一個控制依賴.
AddRecomputeControlDependencyNodes:修改計算圖,添加觸發器節點,返回一個recomputed_source_nodes到trigger nodes的映射.

BuildSwapPair方法: 創建swap-in/swap-out節點對,
FindSwapTrigger方法: max_trigger_time存儲了swap操作需要提前執行,將數據載入回到加速卡上,同時不影響下游計算的時間.也就是swap操作需要提前執行.

Optimize方法: 1. 找出所有_swap_to_host的節點 2. 評估每個節點需要swap的數據大小,以及傳輸時間(假設是基於PCIE 16GBps). 3. 遍歷swap節點,找出swap trigger,找出我們需要將數據交換回來之前的節點執行,並且添加一個從這個節點到swap節點的控制依賴. 4. 將屬性標記爲swap_to_host的節點所有的tensor交換出去.同時添加必要的控制依賴用於延遲swap操作的執行.

備註:從代碼邏輯來看,是將target節點的輸入進行swap,重計算target的值.

auto parallel

主要操作: 自動並行一個圖,通過將batch維度進行切分.可以理解爲數據並行.
根據可用的gpu數量添加replica節點以及shared節點.

arithmetic

主要操作: 通過降低數值計算的複雜度來優化TF計算.
對數值表達式進行簡化,移除冗餘計算,表達式替換等手段.

以上內容僅僅很粗略的閱讀了一些代碼,後面會不斷細化.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章