TensorRT學習(三)通過自定義層擴展TensorRT

  本文源於學習TensorRT文檔《TensorRT-Developer-Guide》第4章“EXTENDING TENSORRT WITH CUSTOM LAYERS”的理解。

通過C++API添加自定義層

  自定義層添加是通過擴展IPluginV2Ext和IPluginCreator類來實現:

  1. IPluginV2Ext:IPluginV2的升級版,實現自定義插件的基類,包含版本化和對其它格式和單精度的處理;
  2. IPluginCreator:自定義層的創建類,可以通過它獲取插件的名稱、版本信息、參數等,也提供網絡創建階段創建插件的方法,並在推理階段反序列化它。

  對定義好的插件可以通過REGISTER_TENSORRT_PLUGIN(pluginCreator)進行靜態註冊,並在使用時通過getPluginRegistry()查詢並使用。官方已經實現的插件有:

  • RPROI_TRT
  • Normalize_TRT
  • PriorBox_TRT
  • GridAnchor_TRT
  • NMS_TRT
  • LReLU_TRT
  • Reorg_TRT
  • Region_TRT
  • Clip_TRT
// 通過getPluginRegistry獲取所有TensorRT插件,creator即IPluginCreator對象
auto creator = getPluginRegistry()->getPluginCreator(pluginName, pluginVersion);
const PluginFieldCollection* pluginFC = creator->getFieldNames();

// 填充該層參數信息,pluginData需要先通過PluginField分配堆上空間
PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields);

// 使用層名和插件參數創建新的插件對象,創建在堆上,需要主動釋放
IPluginV2 *pluginObj = creator->createPlugin(layerName, pluginData);

// 在網絡上添加一層,並將該層和插件綁定,layer即IPluginV2Layer對象
auto layer = network.addPluginV2(&inputs[0], int(inputs.size()), pluginObj);

// TODO:創建最新的網絡,並序列化引擎

// 銷燬插件對象
pluginObj->destroy() 

// TODO:釋放TensorRT資源,network、engine、builder
// TODO:釋放顯存空間,如原網絡參數信息pluginData 

  TensorRT的引擎會在序列化時內部存儲IPluginV2插件的屬性信息,並在反序列化時通過插件註冊表進行查找,並通過IPluginV2::destroy()接口內部銷燬。
  過去的版本中,用戶必須通過nvinfer1::IPluginFactory類在反序列化時創建插件,現在的TensorRT版本可以使用addPluginV2即可。例如:

// 使用Caffe解釋器解析網絡並添加插件
// 如果使用IPluginExt創建插件,需要搭配nvinfer1::IPluginFactory 和 nvinfer1::IPluginFactory
class FooPlugin : public IPluginExt
{
	// TODO:創建插件實現方法
};
class MyPluginFactory : 
public nvinfer1::IPluginFactory, 
public nvcaffeparser1::IPluginFactoryExt
{
	// TODO:創建插件的工廠方法
};

// 如果使用IPluginV2創建並註冊插件,則不再需要實現nvinfer1::IPluginFactory,
// 但需要通過nvcaffeparser1::IPluginFactoryV2 和 IPluginCreator來完成註冊
class FooPlugin : public IPluginV2
{
	// TODO:創建插件實現方法
};
class FooPluginFactory : public nvcaffeparser1::IPluginFactoryV2
{
	virtual nvinfer1::IPluginV2* createPlugin(...)
	{
		// TODO:創建並返回插件對象,如FooPlugin
	}
	bool isPlugin(const char* name)
	{
		// TODO:通過網絡層的名字檢驗是否使用該插件
	}
}
class FooPluginCreator : public IPluginCreator
{
	// TODO:實現所有的插件創建
};
REGISTER_TENSORRT_PLUGIN(FooPluginCreator);

  具體的插件創建實例可以查看:

  • samplePlugin:自定義Caffe網絡插件方法;
  • sampleFasterRCNN:通過TensorRT註冊Caffe網絡插件;
  • sampleUffSSD:對UFF(針對TensorFlow)添加插件。

使用自定義插件

  該部分內容基本與創建時介紹的情況雷同,需要注意的是對於Caffe解釋器,可以通過setPluginFactoryV2 和 IPluginFactoryV2使用自定義插件,那麼在反序列化時創建的插件會按照 IPluginExt::destroy()中定義的內容內部銷燬而無需手動調用,用戶只需要銷燬創建創建過程中的插件對象。

API描述

IPluginV2的API

  1、獲取插件輸出數據結構,檢驗是否可以和相鄰層對接:

  • getNbOutputs:驗證輸出張量數目;
  • getOutputDimensions:驗證輸入維度,獲取輸出維度;
  • supportsFormat:設置插件支持的數據類型,如何種處理精度;
  • getOutputDataType:插件輸出數據的類型(NCHW、NC/2HW2 、NHWC8等,見PluginFormatType)。

  2、獲取插件除了輸入輸出外,需要佔用多大的空間存儲數據,在builder中調用並預分配:

  • getWorkspaceSize

  3、插件在創建階段會多次配置、初始化、執行、中止,而運行時只會多次執行,配置、初始化、中止只執行一次,initialize申請的內存需要在terminate時被釋放,其它的內存需要在destroy釋放,所需要的插件爲:

  • configurePlugin:配置輸入輸出屬性(數量、維度、類型、廣播、格式選擇、最大BatchSize),插件會選擇最合適的算法和數據結構;
  • initialize:在插件配置和推理引擎創建之後使用,根據設置的數據結構配置並準備執行;
  • enqueue:插件實際處理過程,需輸入運行BatchSize、輸入指針、輸出指針、緩存空間指針、CUDA流;
  • terminate:在引擎的上下文被釋放時釋放插件的所有資源;
  • clone:在需要一個獨立插件時(新的builder、network、engine被創建)使用;
  • destroy:在builder、network、engine銷燬時調用,釋放對應的插件資源;
  • set/getPluginNamespace:設置或獲取插件的命名空間,默認爲""(空)。

  4、通過IPluginV2Ext可以實現輸入輸出的廣播性質,需要實現:

  • canBroadcastInputAcrossBatch:判斷輸入張量是否可以在批中進行廣播,能廣播則返回true,TensorRT不會複製輸入並使用同一輸入副本;不能廣播返回false,TensorRT會複製輸入張量;
  • isOutputBroadcastAcrossBatch:指定索引的輸出是否被廣播。

IPluginCreator的API

  IPluginCreator中用來從插件庫中查找並創建插件的方法:

  • getPluginName:獲取插件的名字,並和getPluginType配合使用;
  • getPluginVersion:返回插件版本,TensorRT內部插件默認爲1;
  • getFieldNames:返回PluginFieldCollection結構數據,包含添加插件的參數名和類型;
  • createPlugin:通過給定的PluginFieldCollection結構參數創建插件,需填充實際所需參數;
  • deserializePlugin:在TensorRT引擎根據插件名和版本信息內部調用,返回用於推理的插件對象;
  • set/getPluginNamespace:creator所在的插件庫命名空間,默認爲""(空)。

從5.x.x遷移到5.1.x

  5.x.x版本中沒有getOutputDataType、isOutputBroadcastAcrossBatch、canBroadcastInputAcrossBatch,configurePlugin是針對configureWithFormat的升級。在遷移到5.1.x時需要實現這些新特性。

virtual nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const = 0;
virtual bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const = 0;
virtual bool canBroadcastInputAcrossBatch(int inputIndex) const = 0;
virtual void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) = 0;

  

  

  

  

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章