OpKernel介紹
在TF的架構中,OpKernel是Ops和硬件的中間層,用來抽象統一各個硬件平臺上的Kernel類和接口。
註冊過程
我們首先大致列出OpKernel註冊的過程,後面再詳細分析,我們按照調用順序,從上層往下說:
- 在各個
xxx_op.cc
文件中調用REGISTER_KERNEL_BUILDER()
- 調用
OpKernelRegistrar
的構造函數 - 並在該構造函數中調用
OpKernelRegistrar::InitInternal
- 調用
GlobalKernelRegistry
獲取保存註冊信息的map - 將Key和kernel保存到map中
分析
現在我們來逐個分析,在上面我們是從調用過程往下走,在這裏,我們嘗試從底層往上走。
1.KernelRegistration
首先我們需要關注的是KernelRegistration類,它用來保存OpKernel註冊所需的信息,包括KernelDef、kernel的名字以及kernel的創建方法factory:
struct KernelRegistration {
KernelRegistration(const KernelDef& d, StringPiece c,
std::unique_ptr<kernel_factory::OpKernelFactory> f)
: def(d), kernel_class_name(c), factory(std::move(f)) {}
const KernelDef def;
const string kernel_class_name;
std::unique_ptr<kernel_factory::OpKernelFactory> factory;
};
2.KernelRegistry
這個結構體用來保存OpKernel的註冊信息KernelRegistration,並將這些信息保存到一個unordered_multimap裏:
struct KernelRegistry {
mutex mu;
std::unordered_multimap<string, KernelRegistration> registry
TF_GUARDED_BY(mu);
};
這個map維持一個Key到OpKernel註冊信息之間的關係,而這個Key,是這樣生成的:
const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label());
既然是unordered_multimap,說明一個Key可以對應多個KernelRegistration。
KernelRegistry的實例是通過下面這個函數構造的:
void* GlobalKernelRegistry() {
static KernelRegistry* global_kernel_registry = []() {
KernelRegistry* registry = new KernelRegistry;
OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
return registry;
}();
return global_kernel_registry;
}
3.OpKernelRegistrar
上面我們提到了OpKernel需要保存的信息,以及這些信息是保存在一個unordered_multimap中的,下面我們要來看這個保存的過程。
我們首先來看這個類的構造函數:
// 構造函數1
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
std::unique_ptr<OpKernelFactory> factory) {
// Perform the check in the header to allow compile-time optimization
// to a no-op, allowing the linker to remove the kernel symbols.
if (kernel_def != nullptr) {
InitInternal(kernel_def, kernel_class_name, std::move(factory));
}
}
//構造函數2
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
OpKernel* (*create_fn)(OpKernelConstruction*)) {
// Perform the check in the header to allow compile-time optimization
// to a no-op, allowing the linker to remove the kernel symbols.
if (kernel_def != nullptr) {
InitInternal(kernel_def, kernel_class_name,
absl::make_unique<PtrOpKernelFactory>(create_fn));
}
}
這裏涉及到另外一個類OpKernelFactory,我們也可以看下它的定義:
class OpKernelFactory {
public:
virtual OpKernel* Create(OpKernelConstruction* context) = 0;
virtual ~OpKernelFactory() = default;
};
從這個類的create函數我們就可以看出,OpKernelRegistrar的亮哥構造函數其實大同小異,第一個參數是kernel_del,第二個參數是kernel_class_name,第三個參數都是創建這個kernel的函數。
我們來看一下OpKernelRegistrar構造函數的核心部分:
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
StringPiece kernel_class_name,
std::unique_ptr<OpKernelFactory> factory) {
// See comments in register_kernel::Name in header for info on _no_register.
if (kernel_def->op() != "_no_register") {
const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label());
auto global_registry =
reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
mutex_lock l(global_registry->mu);
global_registry->registry.emplace(
key,
KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
}
}
這個GlobalKernelRegistry我們之前已經說過了,它返回的是一個KernelRegistry實例,global_registry->registry 就是我們之前說的保存註冊信息的map,也就是說,OpKernel的註冊發生在OpKernelRegistrar的構造函數中!
我們順藤摸瓜,看看這個構造函數是怎麼被調用的。
4. REGISTER_KERNEL_BUILDER
OpKernelRegistrar的構造就是在REGISTER_KERNEL_BUILDER宏定義中:
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
constexpr bool should_register_##ctr##__flag = \
SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \
static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \
should_register_##ctr##__flag \
? ::tensorflow::register_kernel::kernel_builder.Build() \
: nullptr, \
#__VA_ARGS__, \
[](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { \
return new __VA_ARGS__(context); \
});
宏定義理解起來往往比較麻煩,不要着急,我們一個個看。
首先做一些宏定義知識的補充,可能不是所有人都清楚(比如我-_-!):
__COUNTER__
可以理解爲一個int型計數器,初始值爲0,每出現一次,值+1
#x
將x轉換成一個字符串
##ctr
變量拼接,就是將ctr的值拼接到整個變量中
__VA_ARGS__
可變參數
有了上面這些知識,我們再來看這些宏就沒這麼複雜了:
- 首先REGISTER_KERNEL_BUILDER接受兩個參數,一個是kernel_builder,另一個是可變參數;
- 將這兩個參數傳給REGISTER_KERNEL_BUILDER_UNIQ_HELPER,而這個宏在前面的宏的基礎上,增加了一個計數器,並將這三個參數傳給下一個定義的宏
- REGISTER_KERNEL_BUILDER_UNIQ接受了這三個參數,然後定義一個臨時變量
should_register_##ctr##__flag
,根據我們上面宏定義的知識,ctr和flag的值都會拼接到register_後面,而這個bool值的結果是SHOULD_REGISTER_OP_KERNEL(#_VA_ARGS),看字面意思就可以理解爲是否需要註冊這個OpKernel;然後定義了一個static的OpKernelRegistrar變量registrar__body__##ctr##__object
,且調用了OpKernelRegistrar的第二類構造函數:
至此我們找到了構造OpKernelRegistrar的地方,也就是說每次使用宏REGISTER_KERNEL_BUILDER註冊OpKernel,都會調用OpKernelRegistrar並將對應的Kernel信息存到map中。
- 我們看一下OpKernelRegistrar構造函數的參數:
1)should_register_##ctr##__flag ? ::tensorflow::register_kernel::kernel_builder.Build() : nullptr 也就是說如果需要創建這個OpKernel,就傳入
::tensorflow::register_kernel::kernel_builder.Build()
這個參數的值我們後面會介紹,根據構造函數的三個參數,我們暫時只需要知道這一長串會返回一個KernelDef對象
2)#__VA_ARGS__
第二個參數是可變參數變成的字符串,也就是kernel_class_name
3)[](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context);
這是一個lamda表達式函數,入參數OpKernelConstruction* context
,返回類型是OpKernel*,這個函數指針本身構成了第三個參數,即OpKernel* (*create_fn)(OpKernelConstruction*)
到此我們應該理解了這個複雜的宏REGISTER_KERNEL_BUILDER
,只需要正確使用這個宏,就可以註冊一個OpKernel!!!
遺留了一個問題,就是爲什麼這個kernel_builder.Build()
,就相當於是KernelDef對象呢?
5.如何使用這個宏?
我們看一下官方的例子:
REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),DummyKernel);
這裏我們看到第一個參數是Name("Test1").Device(tensorflow::DEVICE_CPU)
這個東西爲什麼就是KernelDef呢?我們看一下這個Name究竟是什麼,說實話這個類不太好找:
class Name : public KernelDefBuilder {
public:
explicit Name(const char* op)
: KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};
原來這個Name類是繼承自KernelDefBuilder類,且在它的構造函數中,調用了基類的構造函數,傳入的是op的名字,我們再來看一下這個基類:
class KernelDefBuilder {
public:
explicit KernelDefBuilder(const char* op_name);
~KernelDefBuilder();
KernelDefBuilder& Device(const char* device_type);
template <typename T>
KernelDefBuilder& AttrConstraint(const char* attr_name, gtl::ArraySlice<T> allowed);
template <typename T>
KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed);
KernelDefBuilder& TypeConstraint(const char* attr_name,
gtl::ArraySlice<DataType> allowed);
KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
template <class T>
KernelDefBuilder& TypeConstraint(const char* attr_name);
KernelDefBuilder& HostMemory(const char* arg_name);
KernelDefBuilder& Label(const char* label);
KernelDefBuilder& Priority(int32 priority);
const KernelDef* Build();
private:
KernelDef* kernel_def_;
TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
};
基類KernelDefBuilder也接受一個op_name作爲構造參數,且我們現在可以看到,剛纔Name(“Test1”)後面的.Device()實際上就是KernelDefBuilder的成員函數,返回的是KernelDefBuilder&類型。
在得到這個KernelDefBuilder&類型的返回值後,在通過調用kernel_builder.Build()方法,就得到了const KernelDef*
類型的返回值,這就回答了我們剛纔的問題!
總結
我們花了很久的時間,就是爲了搞清楚TF究竟是如何設計和實現Opkernel的註冊的。我們先是簡單介紹了從調用到底層實現,然後詳細的從底層開始分析了每一步的實現。不得不說TF這一套東西很複雜,但是隻要多看兩遍,也可以理解。
對於OpKernel類來說,往下有它自身的數據類和數據管理類,以及構造輔助類,往上被封裝到一個宏定義中,在後面說到Op的時候,會發現整體思路和OpKernel十分相似,所以理解其中一個,另一個理解起來是水到渠成。