ONNX內部節點修改方法

ONNX內部節點修改方法

承接上回《PyTorch轉ONNX之F.interpolate》,因爲op10計算輸出大小問題,導致我上採樣的結果的大小出現小數,由預期輸出結果output_size=[1., 3., 9., 9.]變成了output_size=[1., 3., 8.999, 8.999],經過後續強制轉換操作抹平成爲了output_size=[1, 3, 8, 8],這就很氣了。

如下圖所示,輸入大小爲input_size=[1, 3, 5, 5],scales爲[1, 1, 1.799, 1.799],根據input_size x scales = output_size,輸出大小應爲output_size=[1., 3., 8.996, 8.996],按之前的描述,後續操作會向下取整得到實際輸出大小output_size=[1, 3, 8, 8]。那麼,如果我將scales人工修改爲[1, 1, 1.801, 1.801]不就可以避開這個問題了嗎。因此,接下來的問題就是,如何修改ONNX的內部節點。

在這裏插入圖片描述

1. 載入ONNX文件

import onnx

onnx_model = onnx.load("test.onnx")
graph = onnx_model.graph
node  = graph.node

for i in range(len(node)):
	print(node[i])

我們可以依靠上述代碼輸出該模型的節點個數,還有節點中的屬性信息,當然也包含靜態圖的鏈路形狀。

2. 搜索目標節點

接着,依據節點ID找到我們需要修改的Resize節點,這裏需要注意的是,Netron可視化出來的id需要經過轉換纔可以得到ONNX的實際ID,就像相對路徑之於絕對路徑一樣,爲了方便,這裏就推薦直接將打印出來的節點信息拷貝出來,進行關鍵字查找。

比如這裏,我的Resize輸出id爲450,那麼就用450作爲關鍵字進行搜索,得到下圖結果。

在這裏插入圖片描述

看起來,這個449就是對應的scales的節點onnx.Constant,所以按照下列代碼,將這個節點的真實ID搜索出來,得到的結果是i=157

for i in range(len(node)):
    if node[i].op_type == 'Constant':
        node_rise = node[i]
        if node_rise.output[0] == '449':
            print(i)  # 157

我們就可以直接使用node[157]直接訪問這個節點了。

3. 修改目標節點

簡單來說,就像鏈表的插入操作一樣,即是刪除、新建、插入。如下列代碼所示:

old_scale_node = node[157]
new_scale_node = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=['449'],
    value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [4], [1, 1, 1.81, 1.81])
)  # 新建新節點
graph.node.remove(old_scale_node)  # 刪除舊節點
graph.node.insert(157, new_scale_node)  # 插入新節點

具體onnx.helper.make_node的使用方法,可以去github上查找doc,然後就可以愉快地隨意修改ONNX模型了。

4. 檢查圖與保存

onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'out.onnx')

可以看到,已成功修改。
在這裏插入圖片描述

完整代碼

import onnx

onnx_model = onnx.load("test.onnx")
graph = onnx_model.graph
node  = graph.node

old_scale_node = node[157]
new_scale_node = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=['449'],
    value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [4], [1, 1, 1.81, 1.81])
)
graph.node.remove(old_scale_node)  
graph.node.insert(157, new_scale_node) 

onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'out.onnx')
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章