這篇文章主要介紹瞭如何使用TensorRT實現自定義算子。
Note:
- 我使用的是TensorRT7.0,自定義算子使用的IPluginV2IOExt實現的。
- 模型框架是caffe,所以以下實現都只適用於caffe模型的解析,但理論上解析tf和onnx的改動不大。
- 實現細節不方便全部貼出,但是基本實現過程和結構都在下面了,照着寫寫沒啥問題了。
其實自定義算子寫多了發現其實還挺好寫的,格式都差不多,主要區別是enqueue的前向計算邏輯可能寫起來複雜些。
整個實現過程基本上是:
- 繼承nvinfer1::IPluginV2IOExt,並實現相應的虛函數。
- 繼承nvinfer1::IPluginCreator並實現相應的虛函數。
- 繼承nvcaffeparser1::IPluginFactoryV2並實現相應的虛函數。
- 在解析網絡之前調用REGISTER_TENSORRT_PLUGIN註冊UpsampleCreator和調用parser->setPluginFactoryV2()以使用自定義層類型。
以Upsample爲例,TensorRT不支持Caffe的Upsample層,所以這裏實現了一個自定義層類型,即plugin。需要實現:
- Upsample類,繼承自nvinfer1::IPluginV2IOExt。
- UpsampleCreator類,繼承自nvinfer1::IPluginCreator。
- CaffePluginFactory類,繼承自nvcaffeparser1::IPluginFactoryV2。
需要實現的函數詳見如下代碼段。
Upsample類的實現:
class Upsample : public nvinfer1::IPluginV2IOExt {
public:
// 直接解析網絡時候需要用到
Upsample();
// 反序列化時候需要用到
Upsample(const void *data, size_t length);
~Upsample();
// 直接return輸出節點數,
int getNbOutputs() override;
// return輸出的維度信息,如:return Dims3(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
Dims getOutputDimensions(int index, const Dims *inputs, int num_input_dims) override;
// pos索引到的input/output的數據格式(format)和數據類型(datatype)如果都支持則返回true
bool supportsFormatCombination(int pos, const PluginTensorDesc* in_out, int num_inputs, int num_outputs) const override;
// 這個函數可以獲取到數據類型和輸入的維度信息,如果有需要用到的可以在這裏將相關信息取出來
configurePlugin(const PluginTensorDesc* in, int num_inputs, const PluginTensorDesc* out, int num_outputs) override;
// 在這裏返回正確的序列化數據的長度,如我要序列化數據類型和數據維度:return sizeof(data_type) + sizeof(chw);
size_t getSerializationSize() const override;
// 序列化函數,在這裏把反序列化時需要用到的參數或數據序列化
void serialize(void *buffer) const override;
// 設置工作空間,不需要直接 return 0;
size_t getWorkspaceSize(int max_batch_size) const override;
// 前向計算的核心函數,計算邏輯在這裏實現,可以使用cublas實現或者自己寫cuda核函數實現
int enqueue(int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) override;
// 調用enqueue的時候需要用到的資源先在這裏Initialize,這個函數是在engine創建之後enqueue調用之前調用的,不需要Initialize則直接 return 0;
int initialize() override;
// 釋放Initialize申請的資源,在enqueue調用之後且engine銷燬之後調用
void terminate() override;
// 返回輸出的數據類型,如何輸入相同,可以直接 return input_types[0];
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* input_types, int num_inputs) const override;
// 返回自定義類型,如這裏是:return Upsample
const char* getPluginType() const override;
// 返回plugin version,沒啥說的
const char* getPluginVersion() const override;
// 銷燬對象
void destroy() override {
delete this;
}
// 在這裏new一個該自定義類型並返回
nvinfer1::IPluginV2Ext* clone() const override;
// 設置命名空間,用來在網絡中查找和創建plugin
void setPluginNamespace(const char* lib_namespace) override;
// 返回plugin對象的命名空間
const char* getPluginNamespace() const override;
bool isOutputBroadcastAcrossBatch(int output_index, const bool* input_is_broadcasted, int num_inputs) const override;
bool canBroadcastInputAcrossBatch(int input_index) const override;
}
下面是對應的Creator類的實現:
class UpsampleCreator : public nvinfer1::IPluginCreator {
public:
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const PluginFieldCollection* getFieldNames() override;
// 創建自定義層pluin的對象並返回
nvinfer1::IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override;
// 創建自定義層pluin的對象並返回,反序列化用到
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serial_data, size_t serial_length) override;
void setPluginNamespace(const char* lib_namespace) override;
const char* getPluginNamespace() const override;
}
下面是對應的plugin factory類的實現:
class CaffePluginFactory : public nvcaffeparser1::IPluginFactoryV2 {
public:
// 在這裏判斷一個層是否爲自定義層類型
bool isPluginV2(const char* name) override;
// 在這裏創建自定義層類型的對象並返回
nvinfer1::IPluginV2* createPlugin(const char* layer_name, const nvinfer1::Weights* weights, int num_weights, const char* libNamespace="") override;
}
如有問題可加公衆號交流:AI算法愛好者