本文源於學習TensorRT文檔《TensorRT-Developer-Guide》第4章“EXTENDING TENSORRT WITH CUSTOM LAYERS”的理解。
通過C++API添加自定義層
自定義層添加是通過擴展IPluginV2Ext和IPluginCreator類來實現:
- IPluginV2Ext:IPluginV2的升級版,實現自定義插件的基類,包含版本化和對其它格式和單精度的處理;
- IPluginCreator:自定義層的創建類,可以通過它獲取插件的名稱、版本信息、參數等,也提供網絡創建階段創建插件的方法,並在推理階段反序列化它。
對定義好的插件可以通過REGISTER_TENSORRT_PLUGIN(pluginCreator)
進行靜態註冊,並在使用時通過getPluginRegistry()
查詢並使用。官方已經實現的插件有:
- RPROI_TRT
- Normalize_TRT
- PriorBox_TRT
- GridAnchor_TRT
- NMS_TRT
- LReLU_TRT
- Reorg_TRT
- Region_TRT
- Clip_TRT
// 通過getPluginRegistry獲取所有TensorRT插件,creator即IPluginCreator對象
auto creator = getPluginRegistry()->getPluginCreator(pluginName, pluginVersion);
const PluginFieldCollection* pluginFC = creator->getFieldNames();
// 填充該層參數信息,pluginData需要先通過PluginField分配堆上空間
PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields);
// 使用層名和插件參數創建新的插件對象,創建在堆上,需要主動釋放
IPluginV2 *pluginObj = creator->createPlugin(layerName, pluginData);
// 在網絡上添加一層,並將該層和插件綁定,layer即IPluginV2Layer對象
auto layer = network.addPluginV2(&inputs[0], int(inputs.size()), pluginObj);
// TODO:創建最新的網絡,並序列化引擎
// 銷燬插件對象
pluginObj->destroy()
// TODO:釋放TensorRT資源,network、engine、builder
// TODO:釋放顯存空間,如原網絡參數信息pluginData
TensorRT的引擎會在序列化時內部存儲IPluginV2插件的屬性信息,並在反序列化時通過插件註冊表進行查找,並通過IPluginV2::destroy()接口內部銷燬。
過去的版本中,用戶必須通過nvinfer1::IPluginFactory類在反序列化時創建插件,現在的TensorRT版本可以使用addPluginV2即可。例如:
// 使用Caffe解釋器解析網絡並添加插件
// 如果使用IPluginExt創建插件,需要搭配nvinfer1::IPluginFactory 和 nvinfer1::IPluginFactory
class FooPlugin : public IPluginExt
{
// TODO:創建插件實現方法
};
class MyPluginFactory :
public nvinfer1::IPluginFactory,
public nvcaffeparser1::IPluginFactoryExt
{
// TODO:創建插件的工廠方法
};
// 如果使用IPluginV2創建並註冊插件,則不再需要實現nvinfer1::IPluginFactory,
// 但需要通過nvcaffeparser1::IPluginFactoryV2 和 IPluginCreator來完成註冊
class FooPlugin : public IPluginV2
{
// TODO:創建插件實現方法
};
class FooPluginFactory : public nvcaffeparser1::IPluginFactoryV2
{
virtual nvinfer1::IPluginV2* createPlugin(...)
{
// TODO:創建並返回插件對象,如FooPlugin
}
bool isPlugin(const char* name)
{
// TODO:通過網絡層的名字檢驗是否使用該插件
}
}
class FooPluginCreator : public IPluginCreator
{
// TODO:實現所有的插件創建
};
REGISTER_TENSORRT_PLUGIN(FooPluginCreator);
具體的插件創建實例可以查看:
- samplePlugin:自定義Caffe網絡插件方法;
- sampleFasterRCNN:通過TensorRT註冊Caffe網絡插件;
- sampleUffSSD:對UFF(針對TensorFlow)添加插件。
使用自定義插件
該部分內容基本與創建時介紹的情況雷同,需要注意的是對於Caffe解釋器,可以通過setPluginFactoryV2 和 IPluginFactoryV2使用自定義插件,那麼在反序列化時創建的插件會按照 IPluginExt::destroy()中定義的內容內部銷燬而無需手動調用,用戶只需要銷燬創建創建過程中的插件對象。
API描述
IPluginV2的API
1、獲取插件輸出數據結構,檢驗是否可以和相鄰層對接:
- getNbOutputs:驗證輸出張量數目;
- getOutputDimensions:驗證輸入維度,獲取輸出維度;
- supportsFormat:設置插件支持的數據類型,如何種處理精度;
- getOutputDataType:插件輸出數據的類型(NCHW、NC/2HW2 、NHWC8等,見PluginFormatType)。
2、獲取插件除了輸入輸出外,需要佔用多大的空間存儲數據,在builder中調用並預分配:
- getWorkspaceSize
3、插件在創建階段會多次配置、初始化、執行、中止,而運行時只會多次執行,配置、初始化、中止只執行一次,initialize申請的內存需要在terminate時被釋放,其它的內存需要在destroy釋放,所需要的插件爲:
- configurePlugin:配置輸入輸出屬性(數量、維度、類型、廣播、格式選擇、最大BatchSize),插件會選擇最合適的算法和數據結構;
- initialize:在插件配置和推理引擎創建之後使用,根據設置的數據結構配置並準備執行;
- enqueue:插件實際處理過程,需輸入運行BatchSize、輸入指針、輸出指針、緩存空間指針、CUDA流;
- terminate:在引擎的上下文被釋放時釋放插件的所有資源;
- clone:在需要一個獨立插件時(新的builder、network、engine被創建)使用;
- destroy:在builder、network、engine銷燬時調用,釋放對應的插件資源;
- set/getPluginNamespace:設置或獲取插件的命名空間,默認爲""(空)。
4、通過IPluginV2Ext可以實現輸入輸出的廣播性質,需要實現:
- canBroadcastInputAcrossBatch:判斷輸入張量是否可以在批中進行廣播,能廣播則返回true,TensorRT不會複製輸入並使用同一輸入副本;不能廣播返回false,TensorRT會複製輸入張量;
- isOutputBroadcastAcrossBatch:指定索引的輸出是否被廣播。
IPluginCreator的API
IPluginCreator中用來從插件庫中查找並創建插件的方法:
- getPluginName:獲取插件的名字,並和getPluginType配合使用;
- getPluginVersion:返回插件版本,TensorRT內部插件默認爲1;
- getFieldNames:返回PluginFieldCollection結構數據,包含添加插件的參數名和類型;
- createPlugin:通過給定的PluginFieldCollection結構參數創建插件,需填充實際所需參數;
- deserializePlugin:在TensorRT引擎根據插件名和版本信息內部調用,返回用於推理的插件對象;
- set/getPluginNamespace:creator所在的插件庫命名空間,默認爲""(空)。
從5.x.x遷移到5.1.x
5.x.x版本中沒有getOutputDataType、isOutputBroadcastAcrossBatch、canBroadcastInputAcrossBatch,configurePlugin是針對configureWithFormat的升級。在遷移到5.1.x時需要實現這些新特性。
virtual nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const = 0;
virtual bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const = 0;
virtual bool canBroadcastInputAcrossBatch(int inputIndex) const = 0;
virtual void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) = 0;