ONNX內部節點修改方法
承接上回《PyTorch轉ONNX之F.interpolate》,因爲op10
的計算輸出大小
問題,導致我上採樣的結果的大小出現小數,由預期輸出結果output_size=[1., 3., 9., 9.]
變成了output_size=[1., 3., 8.999, 8.999]
,經過後續強制轉換操作抹平成爲了output_size=[1, 3, 8, 8]
,這就很氣了。
如下圖所示,輸入大小爲input_size=[1, 3, 5, 5]
,scales爲[1, 1, 1.799, 1.799]
,根據input_size x scales = output_size
,輸出大小應爲output_size=[1., 3., 8.996, 8.996]
,按之前的描述,後續操作會向下取整得到實際輸出大小output_size=[1, 3, 8, 8]
。那麼,如果我將scales人工修改爲[1, 1, 1.801, 1.801]
不就可以避開這個問題了嗎。因此,接下來的問題就是,如何修改ONNX的內部節點。
1. 載入ONNX文件
import onnx
onnx_model = onnx.load("test.onnx")
graph = onnx_model.graph
node = graph.node
for i in range(len(node)):
print(node[i])
我們可以依靠上述代碼輸出該模型的節點個數,還有節點中的屬性信息,當然也包含靜態圖的鏈路形狀。
2. 搜索目標節點
接着,依據節點ID找到我們需要修改的Resize節點,這裏需要注意的是,Netron可視化出來的id需要經過轉換纔可以得到ONNX的實際ID,就像相對路徑之於絕對路徑一樣,爲了方便,這裏就推薦直接將打印出來的節點信息拷貝出來,進行關鍵字查找。
比如這裏,我的Resize輸出id爲450,那麼就用450作爲關鍵字進行搜索,得到下圖結果。
看起來,這個449就是對應的scales的節點onnx.Constant
,所以按照下列代碼,將這個節點的真實ID搜索出來,得到的結果是i=157
for i in range(len(node)):
if node[i].op_type == 'Constant':
node_rise = node[i]
if node_rise.output[0] == '449':
print(i) # 157
我們就可以直接使用node[157]
直接訪問這個節點了。
3. 修改目標節點
簡單來說,就像鏈表的插入操作一樣,即是刪除、新建、插入。如下列代碼所示:
old_scale_node = node[157]
new_scale_node = onnx.helper.make_node(
"Constant",
inputs=[],
outputs=['449'],
value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [4], [1, 1, 1.81, 1.81])
) # 新建新節點
graph.node.remove(old_scale_node) # 刪除舊節點
graph.node.insert(157, new_scale_node) # 插入新節點
具體onnx.helper.make_node
的使用方法,可以去github上查找doc,然後就可以愉快地隨意修改ONNX模型了。
4. 檢查圖與保存
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'out.onnx')
可以看到,已成功修改。
完整代碼
import onnx
onnx_model = onnx.load("test.onnx")
graph = onnx_model.graph
node = graph.node
old_scale_node = node[157]
new_scale_node = onnx.helper.make_node(
"Constant",
inputs=[],
outputs=['449'],
value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [4], [1, 1, 1.81, 1.81])
)
graph.node.remove(old_scale_node)
graph.node.insert(157, new_scale_node)
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'out.onnx')