PyTorch Android demo: a pitfall when deploying my own custom-trained model

A note on a small pitfall I hit when adding my own model (a VGG network with the number of output classes changed to 40) to the demo project from GitHub:

2021-01-26 19:02:42.191 19212-19370/org.pytorch.demo E/AndroidRuntime: FATAL EXCEPTION: ModuleActivity
    Process: org.pytorch.demo, PID: 19212
    com.facebook.jni.CppException: 
    
    aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor):
    Expected at most 12 arguments but found 13 positional arguments.
    :
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py(419): _conv_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py(423): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/container.py(117): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/garbage_classify/mycode/vgg.py(42): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/garbage_classify/mycode/myVGG.py(25): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/jit/_trace.py(934): trace_module
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/jit/_trace.py(733): trace
    /home/xutengfei/garbage_classify/mycode/deployment_script.py(32): <module>
    Serialized   File "code/__torch__/torch/nn/modules/conv.py", line 10
        input: Tensor) -> Tensor:
        _0 = self.bias
        input0 = torch._convolution(input, self.weight, _0, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
                 ~~~~~~~~~~~~~~~~~~ <--- HERE
        return input0
    
        at org.pytorch.NativePeer.initHybrid(Native Method)
        at org.pytorch.NativePeer.<init>(NativePeer.java:24)
        at org.pytorch.Module.load(Module.java:23)
        at org.pytorch.demo.vision.ImageClassificationActivity.analyzeImage(ImageClassificationActivity.java:166)
        at org.pytorch.demo.vision.ImageClassificationActivity.analyzeImage(ImageClassificationActivity.java:31)
        at org.pytorch.demo.vision.AbstractCameraXActivity.lambda$setupCameraX$2$AbstractCameraXActivity(AbstractCameraXActivity.java:90)
        at org.pytorch.demo.vision.-$$Lambda$AbstractCameraXActivity$t0OjLr-l_M0-_0_dUqVE4yqEYnE.analyze(Unknown Source:2)
        at androidx.camera.core.ImageAnalysisAbstractAnalyzer.analyzeImage(ImageAnalysisAbstractAnalyzer.java:57)
        at androidx.camera.core.ImageAnalysisNonBlockingAnalyzer$1.run(ImageAnalysisNonBlockingAnalyzer.java:135)
        at android.os.Handler.handleCallback(Handler.java:900)
        at android.os.Handler.dispatchMessage(Handler.java:103)
        at android.os.Looper.loop(Looper.java:219)
        at android.os.HandlerThread.run(HandlerThread.java:67)

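The key line in the error output is: Expected at most 12 arguments but found 13 positional arguments.
Comparing the operator signature with the serialized call argument by argument shows that the call really does pass one extra True: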

aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)
input, self.weight, _0, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True
The first twelve values line up with the twelve parameters in the signature. The final True is the thirteenth argument, and it is the one causing the problem.

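The manual comparison above can also be done mechanically. The following is a minimal sketch, not part of the original debugging session: it splits the parameter list from the error message and the argument list from the serialized call on top-level commas and reports what is left over. The helper name `split_top_level` is made up for this illustration.

```python
def split_top_level(args):
    """Split a comma-separated string, ignoring commas inside [] or ()."""
    parts, depth, current = [], 0, []
    for ch in args:
        if ch in "[(":
            depth += 1
        elif ch in "])":
            depth -= 1
        if ch == "," and depth == 0:
            # A comma outside any bracket ends the current argument.
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts


# Parameter list copied from the aten::_convolution schema in the error message.
schema_params = (
    "Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, "
    "int[] dilation, bool transposed, int[] output_padding, int groups, "
    "bool benchmark, bool deterministic, bool cudnn_enabled"
)
# Argument list copied from the serialized torch._convolution call.
call_args = (
    "input, self.weight, _0, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, "
    "False, False, True, True"
)

expected = split_top_level(schema_params)
found = split_top_level(call_args)
extra = found[len(expected):]  # arguments with no matching parameter

print(len(expected), len(found), extra)
```

The split tracks bracket depth so that the comma inside `[1, 1]` is not mistaken for an argument separator.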
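Searching online for this error message suggested a version problem, most likely a mismatch between the PyTorch version used to trace the model and the pytorch-android runtime bundled in the app.
Sure enough, a GitHub issue had the answer: update the pytorch-android dependencies in build.gradle to the latest version (1.7.0 at the time of writing):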

implementation 'org.pytorch:pytorch_android:1.7.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.7.0'
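To catch this kind of mismatch before the app crashes, one option is to compare the version of PyTorch that runs the export script with the pytorch_android version pinned in build.gradle. The sketch below is only an illustration under stated assumptions: the helper names are hypothetical, the version strings in the usage lines are example values rather than ones taken from the original setup, and treating equal major.minor numbers as "compatible" is a rule of thumb, not a documented guarantee.

```python
def major_minor(version):
    """Return (major, minor) from a version string such as '1.7.1+cu110'."""
    core = version.split("+")[0]          # drop a local build suffix like '+cu110'
    major, minor = core.split(".")[:2]    # ignore the patch number
    return int(major), int(minor)


def runtimes_match(export_version, android_version):
    """Heuristic: assume the versions are compatible if major.minor agree."""
    return major_minor(export_version) == major_minor(android_version)


# Example values only. In a real export script, export_version could come
# from torch.__version__, and android_version from build.gradle.
print(runtimes_match("1.7.1+cu110", "1.7.0"))  # same major.minor
print(runtimes_match("1.7.1", "1.6.0"))        # differing minor version
```

Recording the export-side version next to the saved model file is another low-cost way to make such a mismatch visible later.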