TensorFlow 源碼閱讀[1] OpKernel的註冊

OpKernel介紹

在這裏插入圖片描述

在TF的架構中,OpKernel是Ops和硬件的中間層,用來抽象統一各個硬件平臺上的Kernel類和接口。

註冊過程

我們首先大致列出OpKernel註冊的過程,後面再詳細分析,我們按照調用順序,從上層往下說:

  1. 在各個xxx_op.cc文件中調用REGISTER_KERNEL_BUILDER()
  2. 調用OpKernelRegistrar的構造函數
  3. 並在該構造函數中調用OpKernelRegistrar::InitInternal
  4. 調用GlobalKernelRegistry獲取保存註冊信息的map
  5. 將Key和kernel保存到map中

分析

現在我們來逐個分析,在上面我們是從調用過程往下走,在這裏,我們嘗試從底層往上走。

1.KernelRegistration

首先我們需要關注的是KernelRegistration類,它用來保存OpKernel註冊所需的信息,包括KernelDef、kernel的名字以及kernel的創建方法factory:

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     std::unique_ptr<kernel_factory::OpKernelFactory> f)
      : def(d), kernel_class_name(c), factory(std::move(f)) {}

  const KernelDef def;
  const string kernel_class_name;
  std::unique_ptr<kernel_factory::OpKernelFactory> factory;
};

2.KernelRegistry

這個結構體用來保存OpKernel的註冊信息KernelRegistration,並將這些信息保存到一個unordered_multimap裏:

struct KernelRegistry {
  mutex mu;
  std::unordered_multimap<string, KernelRegistration> registry
      TF_GUARDED_BY(mu);
};

這個map維持一個Key到OpKernel註冊信息之間的關係,而這個Key,是這樣生成的:

const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());

既然是unordered_multimap,說明一個Key可以對應多個KernelRegistration。
KernelRegistry的實例是通過下面這個函數構造的:

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = []() {
    KernelRegistry* registry = new KernelRegistry;
    OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
    return registry;
  }();
  return global_kernel_registry;
}

3.OpKernelRegistrar

上面我們提到了OpKernel需要保存的信息,以及這些信息是保存在一個unordered_multimap中的,下面我們要來看這個保存的過程。
我們首先來看這個類的構造函數:

// 構造函數1
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name, std::move(factory));
    }
  }

//構造函數2
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }

這裏涉及到另外一個類OpKernelFactory,我們也可以看下它的定義:

class OpKernelFactory {
 public:
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};

從這個類的create函數我們就可以看出,OpKernelRegistrar的亮哥構造函數其實大同小異,第一個參數是kernel_del,第二個參數是kernel_class_name,第三個參數都是創建這個kernel的函數。
我們來看一下OpKernelRegistrar構造函數的核心部分:

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());

	auto global_registry =
	        reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
	    mutex_lock l(global_registry->mu);
	    global_registry->registry.emplace(
	        key,
	        KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
	}
}

這個GlobalKernelRegistry我們之前已經說過了,它返回的是一個KernelRegistry實例,global_registry->registry 就是我們之前說的保存註冊信息的map,也就是說,OpKernel的註冊發生在OpKernelRegistrar的構造函數中!
我們順藤摸瓜,看看這個構造函數是怎麼被調用的。

4. REGISTER_KERNEL_BUILDER

OpKernelRegistrar的構造就是在REGISTER_KERNEL_BUILDER宏定義中:

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
  constexpr bool should_register_##ctr##__flag =                      \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
  static ::tensorflow::kernel_factory::OpKernelRegistrar              \
      registrar__body__##ctr##__object(                               \
          should_register_##ctr##__flag                               \
              ? ::tensorflow::register_kernel::kernel_builder.Build() \
              : nullptr,                                              \
          #__VA_ARGS__,                                               \
          [](::tensorflow::OpKernelConstruction* context)             \
              -> ::tensorflow::OpKernel* {                            \
            return new __VA_ARGS__(context);                          \
          });

宏定義理解起來往往比較麻煩,不要着急,我們一個個看。

首先做一些宏定義知識的補充,可能不是所有人都清楚(比如我-_-!):

__COUNTER__ 可以理解爲一個int型計數器,初始值爲0,每出現一次,值+1
#x 將x轉換成一個字符串
##ctr 變量拼接,就是將ctr的值拼接到整個變量中
__VA_ARGS__可變參數

有了上面這些知識,我們再來看這些宏就沒這麼複雜了:

  1. 首先REGISTER_KERNEL_BUILDER接受兩個參數,一個是kernel_builder,另一個是可變參數;
  2. 將這兩個參數傳給REGISTER_KERNEL_BUILDER_UNIQ_HELPER,而這個宏在前面的宏的基礎上,增加了一個計數器,並將這三個參數傳給下一個定義的宏
  3. REGISTER_KERNEL_BUILDER_UNIQ接受了這三個參數,然後定義一個臨時變量should_register_##ctr##__flag,根據我們上面宏定義的知識,ctr和flag的值都會拼接到register_後面,而這個bool值的結果是SHOULD_REGISTER_OP_KERNEL(#_VA_ARGS),看字面意思就可以理解爲是否需要註冊這個OpKernel;然後定義了一個static的OpKernelRegistrar變量registrar__body__##ctr##__object,且調用了OpKernelRegistrar的第二類構造函數:

至此我們找到了構造OpKernelRegistrar的地方,也就是說每次使用宏REGISTER_KERNEL_BUILDER註冊OpKernel,都會調用OpKernelRegistrar並將對應的Kernel信息存到map中。

  1. 我們看一下OpKernelRegistrar構造函數的參數:

1)should_register_##ctr##__flag ? ::tensorflow::register_kernel::kernel_builder.Build() : nullptr 也就是說如果需要創建這個OpKernel,就傳入::tensorflow::register_kernel::kernel_builder.Build()這個參數的值我們後面會介紹,根據構造函數的三個參數,我們暫時只需要知道這一長串會返回一個KernelDef對象
2) #__VA_ARGS__ 第二個參數是可變參數變成的字符串,也就是kernel_class_name
3)[](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context);這是一個lamda表達式函數,入參數OpKernelConstruction* context,返回類型是OpKernel*,這個函數指針本身構成了第三個參數,即OpKernel* (*create_fn)(OpKernelConstruction*)

到此我們應該理解了這個複雜的宏REGISTER_KERNEL_BUILDER,只需要正確使用這個宏,就可以註冊一個OpKernel!!!
遺留了一個問題,就是爲什麼這個kernel_builder.Build(),就相當於是KernelDef對象呢?

5.如何使用這個宏?

我們看一下官方的例子:

REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),DummyKernel);

這裏我們看到第一個參數是Name("Test1").Device(tensorflow::DEVICE_CPU)這個東西爲什麼就是KernelDef呢?我們看一下這個Name究竟是什麼,說實話這個類不太好找:


class Name : public KernelDefBuilder {
 public:
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

原來這個Name類是繼承自KernelDefBuilder類,且在它的構造函數中,調用了基類的構造函數,傳入的是op的名字,我們再來看一下這個基類:

class KernelDefBuilder {
 public:
  explicit KernelDefBuilder(const char* op_name);
  ~KernelDefBuilder();
  KernelDefBuilder& Device(const char* device_type);
  template <typename T>
  KernelDefBuilder& AttrConstraint(const char* attr_name, gtl::ArraySlice<T> allowed);
  template <typename T>
  KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed);
  KernelDefBuilder& TypeConstraint(const char* attr_name,
gtl::ArraySlice<DataType> allowed);
  KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
  template <class T>
  KernelDefBuilder& TypeConstraint(const char* attr_name);
  KernelDefBuilder& HostMemory(const char* arg_name);
  KernelDefBuilder& Label(const char* label);
  KernelDefBuilder& Priority(int32 priority);
  const KernelDef* Build();
 private:
  KernelDef* kernel_def_;
  TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
};

基類KernelDefBuilder也接受一個op_name作爲構造參數,且我們現在可以看到,剛纔Name(“Test1”)後面的.Device()實際上就是KernelDefBuilder的成員函數,返回的是KernelDefBuilder&類型。

在得到這個KernelDefBuilder&類型的返回值後,在通過調用kernel_builder.Build()方法,就得到了const KernelDef* 類型的返回值,這就回答了我們剛纔的問題!

總結

我們花了很久的時間,就是爲了搞清楚TF究竟是如何設計和實現Opkernel的註冊的。我們先是簡單介紹了從調用到底層實現,然後詳細的從底層開始分析了每一步的實現。不得不說TF這一套東西很複雜,但是隻要多看兩遍,也可以理解。

對於OpKernel類來說,往下有它自身的數據類和數據管理類,以及構造輔助類,往上被封裝到一個宏定義中,在後面說到Op的時候,會發現整體思路和OpKernel十分相似,所以理解其中一個,另一個理解起來是水到渠成。

參考

  1. TF源碼
  2. 『深度長文』Tensorflow代碼解析(三)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章