最近在實驗中,需要用到tensorflow建立一個簡單的模型,但鑑於部分要求比較苛刻,不能直接使用其內置的layer,因此需要自定義一個layer類,這便涉及到了對__init__()
, build()
, call()
這三個函數的理解
先看官方手冊中使用了Layer中的這三個關鍵函數的一個簡單的實例:
class MyDenseLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs):
super(MyDenseLayer, self).__init__()
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_variable("kernel",
shape=[int(input_shape[-1]),
self.num_outputs])
def call(self, input):
return tf.matmul(input, self.kernel)
layer = MyDenseLayer(10)
從直觀上理解,似乎__init__()
和build()
函數都在對Layer進行初始化,都初始化了一些成員函數,而call()
函數則是在該layer被調用時執行。
顯然,這三個函數都是從tf.keras.layers.Layer
處繼承而來的,那麼不妨看一下官方對這幾個函數作何解釋。
下圖爲tf.keras.layers.Layer
的官方文檔
簡單翻譯,就是說官方推薦凡是tf.keras.layers.Layer
的派生類都要實現__init__()
,build()
, call()
這三個方法
__init__()
:保存成員變量的設置
build()
:在call()
函數第一次執行時會被調用一次,這時候可以知道輸入數據的shape
。返回去看一看,果然是__init__()
函數中只初始化了輸出數據的shape
,而輸入數據的shape
需要在build()
函數中動態獲取,這也解釋了爲什麼在有__init__()
函數時還需要使用build()
函數
call()
: call()
函數就很簡單了,即當其被調用時會被執行。
下面附上這幾個函數的文檔,就不做詳細介紹了,有興趣可以自己看看: