[Environment Setup] onnxruntime

1. Introduction
    onnxruntime is Microsoft's engine for running inference on ONNX models; the build below enables its CUDA and TensorRT execution providers for GPU acceleration.

2. Installation
2.1 onnxruntime

git clone https://github.com/microsoft/onnxruntime
cd onnxruntime
git submodule sync
git submodule update --init --recursive 
./build.sh \
	--use_cuda \
	--cuda_version=10.0 \
	--cuda_home=/usr/local/cuda \
	--cudnn_home=/usr/local/cuda \
	--use_tensorrt --tensorrt_home=$HOME/TensorRT \
	--build_shared_lib --enable_pybind \
	--build_wheel --update --build
pip install build/Linux/Debug/dist/onnxruntime_gpu_tensorrt-1.3.0-cp36-cp36m-linux_x86_64.whl
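
A quick sanity check that the wheel installed into the active Python environment and that the build has GPU support (a minimal sketch; the exact version string depends on the wheel you built):

import onnxruntime
print(onnxruntime.__version__)   # e.g. 1.3.0 for the wheel built above
print(onnxruntime.get_device())  # "GPU" for a CUDA/TensorRT build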

2.2 cmake (version >= 3.12.0)

The onnxruntime build requires cmake >= 3.12.0; if the distribution's package is older, remove it and build a recent release from source:

sudo apt-get install libssl-dev
sudo apt-get autoremove cmake  # uninstall the old cmake
wget https://cmake.org/files/v3.17/cmake-3.17.3.tar.gz
tar -xf cmake-3.17.3.tar.gz
cd cmake-3.17.3
./bootstrap
make -j 8
sudo make install

cmake --version
cmake version 3.17.3
CMake suite maintained and supported by Kitware (kitware.com/cmake).

3. onnxruntime and PyTorch inference

import torch
import torchvision
import numpy as np
import time
import onnxruntime

def model_run(model, inp, description):
    # Time a single forward pass with a rough wall-clock measurement.
    t1 = time.time()
    out = model(inp)
    t2 = time.time()
    use_time = t2 - t1
    print(description, "inference time:", use_time)
    return out, use_time

onnx_model_name = "resnet18.onnx"

x = torch.rand(1, 3, 224, 224).float()
model_pt = torchvision.models.resnet18(pretrained=True).cuda()
model_pt.eval()
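# Note: the timing below is a cold first call on the GPU, so it includes
# one-off CUDA initialization and does not synchronize the device; treat
# it as a rough wall-clock estimate.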
out_pt, use_time_pt = model_run(model_pt, x.cuda(), "pytorch resnet18")
out_pt = out_pt.detach().cpu().numpy()
print("pytorch label:", out_pt.argmax(1))

# Export the PyTorch model to ONNX; export traces the model with the example input x
torch.onnx.export(model_pt, x.cuda(), onnx_model_name)

# onnxruntime inference on GPU
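# With this GPU build (onnxruntime 1.3.0 with CUDA/TensorRT), the GPU
# execution providers are selected by default; note that newer releases
# (>= 1.9) require an explicit providers=[...] argument instead.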
sess = onnxruntime.InferenceSession(onnx_model_name)

# Look up the input/output tensor names, similar to TensorFlow
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

x = x.numpy().astype(np.float32)
t1 = time.time()
result = sess.run([output_name], {input_name: x})
t2 = time.time()
use_time_onnxruntime = t2 - t1
result = np.array(result[0])
print("onnxruntime label:", result.argmax(1))
print("onnxruntime inference time:", use_time_onnxruntime)

mse = np.mean((result - out_pt)**2)
print("mse:", mse)
print("multiple:",use_time_pt / use_time_onnxruntime)

Output:

pytorch resnet18 inference time: 0.14845037460327148
pytorch label: [111]
onnxruntime label: [111]
onnxruntime inference time: 0.007770538330078125
mse: 1.0556704e-12
multiple: 19.104258713794795
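
The ~19x gap above should be read with some care: the PyTorch number comes from a cold first call, which includes one-off CUDA initialization overhead. To confirm which execution providers onnxruntime actually registered, the session can be inspected directly. A minimal sketch, assuming the resnet18.onnx file exported above exists and this onnxruntime version exposes get_available_providers:

import onnxruntime

# Providers compiled into this build, in priority order
print(onnxruntime.get_available_providers())

# Providers registered for a concrete session/model
sess = onnxruntime.InferenceSession("resnet18.onnx")
print(sess.get_providers())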