TensorRT的自定義算子Plugin的實現

這篇文章主要介紹瞭如何使用TensorRT實現自定義算子。

Note:

  1. 我使用的是TensorRT7.0,自定義算子使用的IPluginV2IOExt實現的。
  2. 模型框架是caffe,所以以下實現都只適用於caffe模型的解析,但理論上解析tf和onnx的改動不大。
  3. 實現細節不方便全部貼出,但是基本實現過程和結構都在下面了,照着寫寫沒啥問題了。

其實自定義算子寫多了發現其實還挺好寫的,格式都差不多,主要區別是enqueue的前向計算邏輯可能寫起來複雜些。
整個實現過程基本上是:

  1. 繼承nvinfer1::IPluginV2IOExt,並實現相應的虛函數。
  2. 繼承nvinfer1::IPluginCreator並實現相應的虛函數。
  3. 繼承nvcaffeparser1::IPluginFactoryV2並實現相應的虛函數。
  4. 在解析網絡之前調用REGISTER_TENSORRT_PLUGIN註冊UpsampleCreator和調用parser->setPluginFactoryV2()以使用自定義層類型。

以Upsample爲例,TensorRT不支持Caffe的Upsample層,所以這裏實現了一個自定義層類型,即plugin。需要實現:

  1. Upsample類,繼承自nvinfer1::IPluginV2IOExt。
  2. UpsampleCreator類,繼承自nvinfer1::IPluginCreator。
  3. CaffePluginFactory類,繼承自nvcaffeparser1::IPluginFactoryV2。

需要實現的函數詳見如下代碼段。

Upsample類的實現:

class Upsample : public nvinfer1::IPluginV2IOExt {
public:
    // 直接解析網絡時候需要用到
    Upsample();
    // 反序列化時候需要用到
    Upsample(const void *data, size_t length);
    ~Upsample();
    
    // 直接return輸出節點數,
    int getNbOutputs() override;
    
    // return輸出的維度信息,如:return Dims3(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
    Dims getOutputDimensions(int index, const Dims *inputs, int num_input_dims) override;
    
    // pos索引到的input/output的數據格式(format)和數據類型(datatype)如果都支持則返回true
    bool supportsFormatCombination(int pos, const PluginTensorDesc* in_out, int num_inputs, int num_outputs) const override;
    
    // 這個函數可以獲取到數據類型和輸入的維度信息,如果有需要用到的可以在這裏將相關信息取出來
    configurePlugin(const PluginTensorDesc* in, int num_inputs, const PluginTensorDesc* out, int num_outputs) override;

    // 在這裏返回正確的序列化數據的長度,如我要序列化數據類型和數據維度:return sizeof(data_type) + sizeof(chw);
    size_t getSerializationSize() const override;
    
    // 序列化函數,在這裏把反序列化時需要用到的參數或數據序列化
    void serialize(void *buffer) const override;
    
    // 設置工作空間,不需要直接 return 0;
    size_t getWorkspaceSize(int max_batch_size) const override;
    
    // 前向計算的核心函數,計算邏輯在這裏實現,可以使用cublas實現或者自己寫cuda核函數實現
    int enqueue(int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) override;
    
    // 調用enqueue的時候需要用到的資源先在這裏Initialize,這個函數是在engine創建之後enqueue調用之前調用的,不需要Initialize則直接 return 0;
    int initialize() override;
    
    // 釋放Initialize申請的資源,在enqueue調用之後且engine銷燬之後調用
    void terminate() override;
    
    // 返回輸出的數據類型,如何輸入相同,可以直接 return input_types[0];
    nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* input_types, int num_inputs) const override;
    
    // 返回自定義類型,如這裏是:return Upsample
    const char* getPluginType() const override;
    
    // 返回plugin version,沒啥說的
    const char* getPluginVersion() const override;
      
    // 銷燬對象
    void destroy() override {
        delete this;
    }
    
    // 在這裏new一個該自定義類型並返回
    nvinfer1::IPluginV2Ext* clone() const override;
    
    // 設置命名空間,用來在網絡中查找和創建plugin
    void setPluginNamespace(const char* lib_namespace) override;
    // 返回plugin對象的命名空間
    const char* getPluginNamespace() const override;
    bool isOutputBroadcastAcrossBatch(int output_index, const bool* input_is_broadcasted, int num_inputs) const override;
    bool canBroadcastInputAcrossBatch(int input_index) const override;
}

下面是對應的Creator類的實現

class UpsampleCreator : public nvinfer1::IPluginCreator {
public:
    const char* getPluginName() const override;
    const char* getPluginVersion() const override;
    const PluginFieldCollection* getFieldNames() override;
    // 創建自定義層pluin的對象並返回
    nvinfer1::IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override;
    // 創建自定義層pluin的對象並返回,反序列化用到
    nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serial_data, size_t serial_length) override;
    void setPluginNamespace(const char* lib_namespace) override;
    const char* getPluginNamespace() const override;
}

下面是對應的plugin factory類的實現

class CaffePluginFactory : public nvcaffeparser1::IPluginFactoryV2 {
public:
    // 在這裏判斷一個層是否爲自定義層類型
    bool isPluginV2(const char* name) override;
    // 在這裏創建自定義層類型的對象並返回
    nvinfer1::IPluginV2* createPlugin(const char* layer_name, const nvinfer1::Weights* weights, int num_weights, const char* libNamespace="") override;
}

如有問題可加公衆號交流:AI算法愛好者
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章