轉換Onnx過程中
1 pytorch轉onnx模型多輸入問題(如:Bert)
Bert模型有三個輸入,因此就要創建三個dummy_input,然後利用一個tuple,傳入函數中。
dummy_input0 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input1 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input2 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
torch.onnx.export(model. (dummy_input0, dummy_input1, dummy_input2), filepath)
2
PyTorch v1.0.1 Reshape不支持報錯 [Solution]
PyTorch v1.2.0 需要升級cuda10.0以上
3 像data[index] = new_data這樣的張量就地索引分配目前在導出中不受支持。解決這類問題的一種方法是使用算子散點,顯式地更新原始張量。
就是像tensorflow的靜態圖,不能隨便改變tensor的值,可以用torch的scatter_方法解決
錯誤的方式
# def forward(self, data, index, new_data):
# data[index] = new_data # 重新賦值
# return data
正確的方式
class InPlaceIndexedAssignmentONNX(torch.nn.Module):
def forward(self, data, index, new_data):
new_data = new_data.unsqueeze(0)
index = index.expand(1, new_data.size(1))
data.scatter_(0, index, new_data)
return data
4. ONNX export failed on ATen operator group_norm because torch.onnx.symbolic.group_norm does not exist
解決:~/anaconda3/envs/py36/lib/python3.6/site-packages/torch/onnx/symbolic.py
@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
return g.op("ATen", input, weight, bias, num_groups_i=num_groups,
eps_f=eps, cudnn_enabled_i=cudnn_enabled, operator_s="group_norm")
5 ‘RuntimeError: ONNX export failed: Couldn’t export operator aten::upsample_bilinear2d’
解決方案
略略略
6 RuntimeError: ONNX export failed: Couldn’t export operator aten::avg_pool2d
#Error
self.global_average = nn.AdaptiveAvgPool2d((1,1))//就是這一行的問題是用的AdaptiveAvgPool2d
#OK
self.global_average = nn.AvgPool2d(kernel_size = (7,7),stride=(7,7),ceil_mode=False)
以後遇到別人代碼使用Adaptive Pooling,可以通過這兩個公式轉換爲標準的Max/AvgPooling:
#只需要知道輸入的input_size ,就可以推導出stride 與kernel_size ,從而替換爲標準的Max/AvgPooling
stride = floor ( (input_size / (output_size) )
kernel_size = input_size − (output_size−1) * stride
padding = 0
PyTorch–>ONNX
這一部分比較簡單,大致照着PyTorch官網的例程走即可。
pytorch導出onnx並檢查
onnx_model = onnx.load(output_onnx)
onnx.checker.check_model(onnx_model) # assuming throw on error
print("==> Passed")
[Solution]
解決:~/anaconda3/envs/py36/lib/python3.6/site-packages/torch/onnx/symbolic.py
在該文件中添加代碼
def reshape(g, self, shape):
return view(g, self, shape)
def reshape_as(g, self, other):
shape = g.op('Shape', other)
return reshape(g, self, shape)
Reference
4 RuntimeError: ONNX export failed: Couldn't export operator aten::avg_pool2d