tensorflow2.0中Layer的__init__(),build(), call()函數

最近在實驗中,需要用到tensorflow建立一個簡單的模型,但鑑於部分要求比較苛刻,不能直接使用其內置的layer,因此需要自定義一個layer類,這便涉及到了對__init__(), build(), call()這三個函數的理解

先看官方手冊中使用了Layer中的這三個關鍵函數的一個簡單的實例:

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_variable("kernel",
                                    shape=[int(input_shape[-1]),
                                           self.num_outputs])

  def call(self, input):
    return tf.matmul(input, self.kernel)

layer = MyDenseLayer(10)

從直觀上理解,似乎__init__()build()函數都在對Layer進行初始化,都初始化了一些成員函數,而call()函數則是在該layer被調用時執行。

顯然,這三個函數都是從tf.keras.layers.Layer處繼承而來的,那麼不妨看一下官方對這幾個函數作何解釋。
下圖爲tf.keras.layers.Layer的官方文檔
在這裏插入圖片描述
簡單翻譯,就是說官方推薦凡是tf.keras.layers.Layer的派生類都要實現__init__()build(), call()這三個方法
__init__():保存成員變量的設置
build():在call()函數第一次執行時會被調用一次,這時候可以知道輸入數據的shape。返回去看一看,果然是__init__()函數中只初始化了輸出數據的shape,而輸入數據的shape需要在build()函數中動態獲取,這也解釋了爲什麼在有__init__()函數時還需要使用build()函數
call()call()函數就很簡單了,即當其被調用時會被執行。

下面附上這幾個函數的文檔,就不做詳細介紹了,有興趣可以自己看看:
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章